Compare commits


703 Commits

Author SHA1 Message Date
Jack Christensen
fc334e4c75
Merge pull request #2322 from Nomlsbad/better-parse-config-errors
Use `ParseConfigError` in `pgx.ParseConfig` and `pgxpool.ParseConfig`
2025-05-18 09:03:38 -05:00
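
For illustration, a minimal sketch of what this change enables: parse failures from both `pgx.ParseConfig` and `pgxpool.ParseConfig` can now be detected as `*pgconn.ParseConfigError` (the invalid connection string below is just an example).

```
package main

import (
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgxpool"
)

func main() {
	_, err := pgxpool.ParseConfig("host=localhost port=not-a-number")

	var pcErr *pgconn.ParseConfigError
	if errors.As(err, &pcErr) {
		fmt.Println("parse config error:", pcErr)
	}
}
```
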
Yurasov Ilia
3f5509fe98 fix: remove fmt import from pgxpool 2025-05-18 02:28:24 +03:00
Jack Christensen
15bca4a4e1 Release v5.7.5 2025-05-17 17:14:14 -05:00
Jack Christensen
1d557f9116 Remove PlanScan memoization
Previously, PlanScan used a cache to improve performance. However, the
cache could get confused in certain cases. For example, the following
would fail:

m := pgtype.NewMap()
var err error

var tags any
err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{foo,bar,baz}"), &tags)
require.NoError(t, err)

var cells [][]string
err = m.Scan(pgtype.TextArrayOID, pgx.TextFormatCode, []byte("{{foo,bar},{baz,quz}}"), &cells)
require.NoError(t, err)

This commit removes the memoization and adds a test to ensure that this
case works.

The benchmarks were also updated to include an array of strings to
ensure this path is benchmarked. As it turned out, there was next to no
performance difference between the cached and non-cached versions.

It's possible there may be a performance impact in certain complicated
cases, but I have not encountered any. If there are any performance
issues, we can optimize the narrower case rather than adding memoization
everywhere.
2025-05-17 16:34:01 -05:00
Jack Christensen
de7fe81d78 Use reflect.TypeFor instead of reflect.TypeOf
The simplified function became available in Go 1.22.
2025-05-17 09:11:31 -05:00
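
For reference, a sketch of the simplification (fmt.Stringer is an arbitrary example type):

```
// Before Go 1.22: getting the reflect.Type of an interface type
// required a pointer round trip.
t := reflect.TypeOf((*fmt.Stringer)(nil)).Elem()

// Go 1.22+:
t = reflect.TypeFor[fmt.Stringer]()
```
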
Jack Christensen
d9eb089bd7 Remove unused function 2025-05-17 09:09:07 -05:00
Jack Christensen
6be24eb08d Fix comment typo 2025-05-17 09:01:55 -05:00
Jack Christensen
07871c0a34 Zero internal baseRows references to allow GC earlier
See https://github.com/jackc/pgx/pull/2269
2025-05-17 08:39:54 -05:00
Yurasov Ilia
de806a11e7 fix: return pgconn.ParseConfigError inside ParseConfig functions 2025-05-14 03:50:09 +03:00
Yurasov Ilia
ce13266e90 fix: move pgconn.NewParseConfigError to errors.go
Move the pgconn.NewParseConfigError function from export_test.go to
errors.go so that pgconn.ParseConfigError values can be created inside
the pgx and pgxpool packages.
2025-05-14 03:43:21 +03:00
Jack Christensen
777e7e5cdf
Merge pull request #2313 from stampy88/tracelog_pool_additions
Implement AcquireTracer and ReleaseTracer for TraceLog
2025-05-10 11:30:57 -05:00
dave sinclair
151bd026ec Switched to LogLevelDebug 2025-05-10 12:28:26 -04:00
Jack Christensen
540fcaa9b9 Add support for PGOPTIONS environment variable
Match libpq behavior for PGOPTIONS environment variable. See
https://www.postgresql.org/docs/current/libpq-envars.html
2025-05-10 11:09:39 -05:00
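
A runnable sketch of the new behavior (assumes a reachable server configured via the other PG* variables; the option value is a placeholder):

```
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	// Equivalent to launching with: PGOPTIONS="-c search_path=myschema" ./app
	os.Setenv("PGOPTIONS", "-c search_path=myschema")

	// An empty connection string makes pgx read the PG* environment
	// variables, now including PGOPTIONS.
	conn, err := pgx.Connect(context.Background(), "")
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())

	var searchPath string
	if err := conn.QueryRow(context.Background(), "show search_path").Scan(&searchPath); err != nil {
		log.Fatal(err)
	}
	fmt.Println("search_path:", searchPath)
}
```
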
Jack Christensen
3a248e3822 Add support for PGTZ environment variable
Match libpq behavior for PGTZ environment variable. See
https://www.postgresql.org/docs/current/libpq-envars.html
2025-05-10 10:58:35 -05:00
dave sinclair
baca2d848a Implement AcquireTracer and ReleaseTracer for TraceLog
- `TraceLog` now implements the `pgxpool.AcquireTracer` and `pgxpool.ReleaseTracer` interfaces to log connection pool interactions.
2025-05-08 16:28:21 -04:00
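
A sketch of wiring this up, assuming the pool picks the tracer up from ConnConfig.Tracer (the logger is a placeholder; any `tracelog.Logger` implementation works):

```
package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/tracelog"
)

// stdLogger is a stand-in tracelog.Logger backed by the standard library.
type stdLogger struct{}

func (stdLogger) Log(ctx context.Context, level tracelog.LogLevel, msg string, data map[string]any) {
	log.Printf("%s: %s %v", level, msg, data)
}

func main() {
	cfg, err := pgxpool.ParseConfig("postgres://app@localhost:5432/appdb")
	if err != nil {
		log.Fatal(err)
	}
	// TraceLog now also satisfies pgxpool.AcquireTracer and
	// pgxpool.ReleaseTracer, so pool checkouts and checkins are logged too.
	cfg.ConnConfig.Tracer = &tracelog.TraceLog{Logger: stdLogger{}, LogLevel: tracelog.LogLevelDebug}

	pool, err := pgxpool.NewWithConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer pool.Close()
}
```
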
Jack Christensen
c911d86cff
Merge pull request #2309 from dzherb/fix_doc
chore: fix typo in doc
2025-05-03 13:46:46 -05:00
dzherb
2bac99e2ae chore: fix typo 2025-04-28 20:58:11 +03:00
Jack Christensen
c92d0a9045 Update golang.org/x/crypto to v0.37.0
This required bumping the minimum Go version to 1.23.0.
2025-04-26 10:09:29 -05:00
Jack Christensen
e9aad0fb0b Add test for tracer within transaction
https://github.com/jackc/pgx/issues/2304
2025-04-26 09:55:31 -05:00
Jack Christensen
9e7f38cd50
Merge pull request #2302 from usernameisnull/pgconn-error
chore: should be pgconn.PgError
2025-04-17 08:52:47 -05:00
bing.ma
e779a5c072 chore: should be pgconn.PgError 2025-04-17 18:30:06 +08:00
Jack Christensen
ff9c26d851 Make OpenDBFromPool docs explicit about closing the *sql.DB
https://github.com/jackc/pgx/issues/2295
2025-04-05 09:01:11 -05:00
Jack Christensen
0f77a2d028
Merge pull request #2293 from divyam234/master
feat: add support for direct sslnegotiation
2025-03-31 08:13:19 -05:00
divyam234
ddd966f09f
update 2025-03-31 15:06:55 +02:00
divyam234
924834b5b4
add pgmock tests 2025-03-31 15:02:07 +02:00
divyam234
9b15554c51
respect sslmode set by user 2025-03-30 16:35:43 +02:00
divyam234
037e4cf9a2
feat: add support for direct sslnegotiation 2025-03-30 16:21:52 +02:00
Jack Christensen
04bcc0219d Add v5.7.4 to changelog 2025-03-24 20:04:45 -05:00
Jack Christensen
0e0a7d8344
Merge pull request #2288 from felix-roehrich/fr/fix-plan-scan
Revert change in `if` from #2236.
2025-03-24 19:46:22 -05:00
Felix Röhrich
63422c7d6c revert change in if 2025-03-24 15:01:50 +01:00
Jack Christensen
5c1fbf4806 Update changelog for v5.7.3 2025-03-21 21:02:03 -05:00
Jack Christensen
05fe5f8b05 Explain seemingly redundant rows.Close() in CollectOneRow
fixes https://github.com/jackc/pgx/issues/2282
2025-03-21 20:33:32 -05:00
Jack Christensen
70c9a147a2
Merge pull request #2279 from djahandarie/min-idle-conns
Add MinIdleConns
2025-03-21 20:25:19 -05:00
Darius Jahandarie
6603ddfbe4
add MinIdleConns 2025-03-15 19:14:26 +09:00
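
A hedged sketch of the new knob, assuming it landed as `pgxpool.Config.MinIdleConns` (connection string and values are placeholders; imports and surrounding code elided):

```
cfg, err := pgxpool.ParseConfig("postgres://app@localhost:5432/appdb")
if err != nil {
	log.Fatal(err)
}
cfg.MinConns = 4     // existing: minimum total connections in the pool
cfg.MinIdleConns = 2 // new: minimum connections kept idle, ready to acquire
pool, err := pgxpool.NewWithConfig(context.Background(), cfg)
```
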
Jack Christensen
70f7cad222 Add link to https://github.com/Arlandaren/pgxWrappy 2025-02-28 20:59:28 -06:00
Jack Christensen
6bf1b0b1b9 Add database/sql to overview of scanning 2025-02-22 08:42:26 -06:00
Jack Christensen
14bda65a0c Correct pgtype docs 2025-02-22 08:34:31 -06:00
Jack Christensen
9e3c4fb40f
Merge pull request #2257 from felix-roehrich/fr/change-connect-logic
Change connection logic to be more forgiving
2025-02-19 07:36:35 -06:00
Felix Röhrich
05e72a5ab1 make connection logic more forgiving 2025-02-17 21:24:38 +01:00
Jack Christensen
47d631e34b Added missed change to v5.7.2 changelog 2025-02-08 10:21:39 -06:00
Jack Christensen
58b05f567c Add https://github.com/nikolayk812/pgx-outbox to README.md
fixes https://github.com/jackc/pgx/issues/2239
2025-01-25 08:59:52 -06:00
Jack Christensen
dcb7193669
Merge pull request #2236 from felix-roehrich/fr/fix-plan-scan
Alternative implementation for JSONCodec.PlanScan
2025-01-25 08:56:38 -06:00
Jack Christensen
1abf7d9050
Merge pull request #2240 from bonnefoa/fix-watch-panic
Unwatch and close connection on a batch write error
2025-01-25 08:38:33 -06:00
Jack Christensen
b5efc90a32
Merge pull request #2028 from jackc/enable-composite-tests-on-cockroachdb
Enable composite tests on cockroachdb
2025-01-25 08:22:32 -06:00
Jack Christensen
a26c93551f Skip TestCompositeCodecTranscodeStructWrapperForTable 2025-01-25 08:15:40 -06:00
Jack Christensen
2100e1da46 Use latest version of CockroachDB for CI 2025-01-25 08:04:42 -06:00
Jack Christensen
2d21a2b80d
Merge pull request #2228 from jackc/fix-xml-decode-value
XMLCodec: fix DecodeValue to return a []byte
2025-01-25 07:24:30 -06:00
Jack Christensen
5f33ee5f07 Call out []byte in QueryExecModeSimpleProtocol documentation
https://github.com/jackc/pgx/issues/2231
2025-01-25 07:15:02 -06:00
Anthonin Bonnefoy
228cfffc20 Unwatch and close connection on a batch write error
Previously, a conn.Write error would simply unlock pgconn, leaving the
connection Idle and reusable while the multiResultReader would be
closed. From this state, calling multiResultReader.Close won't try to
receiveMessage, and thus won't unwatch and close the connection, since
the reader is already closed. This leaves the connection "open", and the
next time it is used a "Watch already in progress" panic can be
triggered.

This patch fixes the issue by unwatching and closing the connection on a
batch write error. The same was done for a Sync.Encode error, even
though that path is unreachable, as Sync.Encode never returns an error.
2025-01-24 08:49:07 +01:00
Felix Röhrich
a5353af354 rework JSONCodec.PlanScan 2025-01-22 22:35:35 +01:00
Jack Christensen
0bc29e3000
Merge pull request #2225 from logicbomb/improve-error-message
Include the field name in error messages when scanning structs
2025-01-18 10:41:13 -06:00
Jack Christensen
9cce05944a
Merge pull request #2216 from pconstantinou/master
Timestamp incorrectly adds 'Z' when serializing to JSON to indicate GMT; fixes bug #2215
2025-01-18 10:17:43 -06:00
Jason Turim
9c0ad690a9 Include the field name in error messages when scanning structs 2025-01-11 14:31:24 -05:00
Jack Christensen
03f08abda3 Fix in Unmarshal function rather than DecodeValue
This preserves backwards compatibility in the unlikely event someone is
using an alternative XML unmarshaler that does support unmarshalling
into *any.
2025-01-11 11:26:46 -06:00
Jack Christensen
2c1b1c389a
Merge pull request #2200 from zenkovev/flush_request_in_pipeline
add flush request in pipeline
2025-01-11 11:15:36 -06:00
Jack Christensen
329cb45913 XMLCodec: fix DecodeValue to return a []byte
Previously, DecodeValue would always return nil with the default
Unmarshal function.

fixes https://github.com/jackc/pgx/issues/2227
2025-01-11 10:55:48 -06:00
zenkovev
c96a55f8c0 private const for pipelineRequestType 2025-01-11 19:54:18 +03:00
Jack Christensen
e87760682f Update oldest supported Go version to 1.22 2025-01-11 07:49:50 -06:00
Jack Christensen
f681632c68 Drop PG 12 support and add PG 17 to CI 2025-01-11 07:49:26 -06:00
Phil Constantinou
3c640a44b6 Making the tests a little cleaner and clear 2025-01-06 09:24:55 -08:00
zenkovev
de3f868c1d pipeline queue for client requests 2025-01-06 13:54:48 +03:00
Phil Constantinou
5424d3c873 Return error and make sure they are unit tested 2025-01-05 19:45:45 -08:00
Phil Constantinou
42d3d00734 Parse as a UTC time 2025-01-05 19:19:17 -08:00
Phil Constantinou
cdc672cf3f Make JSON output conform to ISO8601 timestamp without a timezone 2025-01-05 13:05:51 -08:00
Phil Constantinou
52e2858629 Added unit test and fixed typo 2025-01-02 13:36:33 -08:00
Phil Constantinou
e352784fed Add Z only if needed. 2025-01-02 12:50:29 -08:00
Jack Christensen
c2175fe46e
Merge pull request #2213 from moukoublen/fix_2204
Fix #2204
2024-12-30 20:35:41 -06:00
Jack Christensen
659823f8f3 Add link to github.com/amirsalarsafaei/sqlc-pgx-monitoring
fixes https://github.com/jackc/pgx/issues/2212
2024-12-30 20:27:10 -06:00
Jack Christensen
ca04098fab
Merge pull request #2136 from ninedraft/optimize-sanitize
Reduce SQL sanitizer allocations
2024-12-30 20:24:13 -06:00
Jack Christensen
4ff0a454e0
Merge pull request #2211 from EinoPlasma/master
Fixes for Method Comment and Typo in Test Function Name
2024-12-30 20:12:22 -06:00
Jack Christensen
00b86ca3db
Merge pull request #2208 from vamshiaruru/feat/expose_empty_acquire_wait_time_from_puddle
Expose puddle.Pool's EmptyAcquireWaitTime in pgxpool's Stats
2024-12-30 20:03:51 -06:00
Kostas Stamatakis
61a0227241
simplify test 2024-12-30 23:15:46 +02:00
Kostas Stamatakis
2190a8e0d1
cleanup and add test for json codec 2024-12-30 23:09:19 +02:00
Kostas Stamatakis
6e9fa42fef
fix #2204 2024-12-30 22:54:42 +02:00
EinoPlasma
6d9e6a726e Fix typo in test function name 2024-12-29 21:03:38 +08:00
EinoPlasma
02e387ea64 Fix method comment in PasswordMessage 2024-12-29 20:59:24 +08:00
merlin
e452f80b1d
TestErrNoRows: remove bad test case 2024-12-28 13:39:01 +02:00
merlin
da0315d1a4
optimisations of quote functions by @sean- 2024-12-28 13:31:09 +02:00
merlin
120c89fe0d
fix preallocations of quoted string 2024-12-28 13:31:09 +02:00
merlin
057937db27
add prefix to quoters tests 2024-12-28 13:31:09 +02:00
merlin
47cbd8edb8
drop too large values from memory pools 2024-12-28 13:31:09 +02:00
merlin
90a77b13b2
add docs to sanitize tests 2024-12-28 13:31:08 +02:00
merlin
59d6aa87b9
rework QuoteString and QuoteBytes as append-style 2024-12-28 13:31:08 +02:00
merlin
39ffc8b7a4
add lexer and query pools
use lexer pool
2024-12-28 13:31:08 +02:00
merlin
c4c1076d28
add FuzzQuoteString and FuzzQuoteBytes 2024-12-28 13:31:08 +02:00
merlin
4293b25262
decrease number of samples in go benchmark 2024-12-28 13:31:08 +02:00
merlin
ea1e13a660
quoteString 2024-12-28 13:31:08 +02:00
merlin
58d4c0c94f
quoteBytes
check new quoteBytes
2024-12-28 13:31:08 +02:00
merlin
1752f7b4c1
docs 2024-12-28 13:31:08 +02:00
merlin
ee718a110d
append AvailableBuffer 2024-12-28 13:31:08 +02:00
merlin
546ad2f4e2
shared bytestring 2024-12-28 13:31:08 +02:00
merlin
efc2c9ff44
buf pool 2024-12-28 13:31:08 +02:00
merlin
aabed18db8
add benchmark tool
fix benchmark script
2024-12-28 13:31:08 +02:00
merlin
afa974fb05
base case
make benchmark more extensive

add quote to string

add BenchmarkSanitizeSQL
2024-12-28 13:31:08 +02:00
Vamshi Aruru
12b37f3218 Expose puddle.Pool's EmptyAcquireWaitTime in pgxpool's Stats
Addresses: https://github.com/jackc/pgx/issues/2205
2024-12-26 13:46:49 +05:30
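
A small sketch of reading the new stat (assumes an existing `*pgxpool.Pool` named `pool`):

```
stat := pool.Stat()
// Cumulative time callers spent waiting in Acquire because no idle
// connection was available, surfaced from puddle by this change.
fmt.Println("empty acquire wait time:", stat.EmptyAcquireWaitTime())
```
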
Jack Christensen
bcf3fbd780
Merge pull request #2206 from alexandear/refactor-impossible-cond
Refactor Conn.LoadTypes by removing redundant check
2024-12-24 11:14:17 -06:00
Jack Christensen
f7c3d190ad
Merge pull request #2203 from martinyonatann/chore/check-array-and-remove-imposible-condition
check array just using `len` and remove `impossible condition`
2024-12-24 11:10:45 -06:00
Jack Christensen
473a241b96
Merge pull request #2202 from martinyonatann/chore/remove-unused-parameter
remove unused func and parameter
2024-12-24 09:32:07 -06:00
Oleksandr Redko
311f72afdc Refactor Conn.LoadTypes by removing redundant check 2024-12-24 12:58:15 +02:00
martinpasaribu
877111ceeb
check array just using len and remove impossible condition 2024-12-22 23:57:28 +07:00
martinpasaribu
dc3aea06b5
remove unused func and parameter 2024-12-22 23:48:08 +07:00
Jack Christensen
e5d321f920
Merge pull request #2197 from alexandear/fix-generated-hdr
Update comments in generated code to align with Go standards
2024-12-21 12:40:23 -06:00
Oleksandr Redko
17cd36818c Update comments in generated code to align with Go standards 2024-12-21 20:21:32 +02:00
Jack Christensen
24fbe353ed Create changelog for v5.7.2 2024-12-21 09:25:36 -06:00
Jack Christensen
3a1593b25b
Merge pull request #2198 from alexandear/fix-nilness
Handle errors in generate_certs
2024-12-21 08:27:55 -06:00
Jack Christensen
9d851d7c98 Fix integration benchmarks 2024-12-21 08:22:12 -06:00
Jack Christensen
dacffdc7e2
Merge pull request #2196 from alexandear/docs-improve-links
Improve links in README
2024-12-21 08:13:57 -06:00
Jack Christensen
bc7c840770
Merge pull request #2195 from LucasHild/master
Add CommitQuery to transaction options
2024-12-21 08:12:58 -06:00
Oleksandr Redko
043685147f Handle errors in generate_certs 2024-12-18 02:31:56 +02:00
Oleksandr Redko
25329273da Improve links in README 2024-12-18 02:02:06 +02:00
Jack Christensen
ad87d47089
Merge pull request #2194 from alexandear/refactor/pgconn-tests
Simplify pgconn tests by using T.TempDir
2024-12-17 17:45:16 -06:00
Oleksandr Redko
7cf7bc6054 Simplify pgconn tests by using T.TempDir 2024-12-17 16:09:32 +02:00
zenkovev
76593f37f7 add flush request in pipeline 2024-12-17 11:49:13 +03:00
Jack Christensen
3e6c719698
Merge pull request #2189 from pankona/update-crypto
Update golang.org/x/crypto v0.27.0 => v0.31.0 to fix vulnerability
2024-12-13 07:54:25 -06:00
Yosuke Akatsuka
5ee33320c6 update golang.org/x/crypto v0.27.0 => v0.31.0 2024-12-12 12:58:14 +00:00
Jack Christensen
ac0b46f2f9 Warn not to create table and use it in the same batch
fixes https://github.com/jackc/pgx/issues/2182
2024-12-05 16:30:48 -06:00
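
The pitfall being documented, sketched with placeholder SQL (assumes an open `*pgx.Conn` named `conn` and a `ctx`):

```
batch := &pgx.Batch{}
batch.Queue("create table widgets (id int)")
// Depending on the query exec mode, pgx may prepare/describe every queued
// statement before any of them runs, so this can fail with
// `relation "widgets" does not exist`.
batch.Queue("insert into widgets (id) values ($1)", 1)
err := conn.SendBatch(ctx, batch).Close()
```
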
Jack Christensen
e3c81cc153
Merge pull request #2169 from thedolphin/master
Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs
2024-11-28 18:09:29 -06:00
Alexander Rumyantsev
4b7e9942b2 Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions 2024-11-19 19:42:49 +03:00
Jack Christensen
b9e2b20fb1
Merge pull request #2162 from evellior/patch-1
Update pgxpool.ParseConfig documentation
2024-11-05 18:14:46 -06:00
Peyton Foley
06a0abb75e
Update pgxpool.ParseConfig documentation
Added default values and example of valid duration string to inline documentation.
2024-11-05 03:09:43 +08:00
Jack Christensen
c76a650f75 Improve documentation for QueryExecModes
https://github.com/jackc/pgx/issues/2157
2024-10-29 19:36:33 -05:00
Jack Christensen
f57b2854f8
Merge pull request #2151 from ludusrusso/fix-2146
handling double pointer on sql.Scanner interface when scanning rows
2024-10-22 18:52:46 -05:00
Ludovico Russo
5c9b565116 fix: #2146
[![Open Source Saturday](https://img.shields.io/badge/%E2%9D%A4%EF%B8%8F-open%20source%20saturday-F64060.svg)](https://www.meetup.com/it-IT/Open-Source-Saturday-Milano/)

Co-authored-by: Alessio Izzo <alessio.izzo86@gmail.com>
2024-10-19 15:43:56 +02:00
Jack Christensen
2ec900454b
Merge pull request #2145 from grachevko/string
Implement pgtype.UUID.String()
2024-10-09 08:46:03 -05:00
Konstantin Grachev
8723855d95
Implement pgtype.UUID.String() 2024-10-09 14:22:10 +03:00
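
A minimal sketch of the new method:

```
u := pgtype.UUID{
	Bytes: [16]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
		0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10},
	Valid: true,
}
fmt.Println(u.String()) // 01020304-0506-0708-090a-0b0c0d0e0f10
```
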
Jack Christensen
3f84e891de
Merge pull request #2142 from jackc/add-xid8
Add xid8 type
2024-10-08 19:04:07 -05:00
Jack Christensen
cc05954369
Merge pull request #2138 from zenkovev/message_body_size_limit
add message body size limits in frontend and backend
2024-10-05 12:35:36 -05:00
Jack Christensen
123b59a57e Ensure planning encodes and scans cannot infinitely recurse
https://github.com/jackc/pgx/issues/2141
2024-10-05 12:20:50 -05:00
zene
10e11952bd changed style of two comments 2024-10-05 19:54:02 +03:00
Jack Christensen
32a6b1b200 Skip xid8 test on PG < 13 and CRDB 2024-10-05 10:44:13 -05:00
Jack Christensen
f0783c6fbe Add xid8 type
https://github.com/jackc/pgx/discussions/2137
2024-10-05 10:16:42 -05:00
zene
0290507ff2 remove global atomics 2024-10-04 09:26:37 +03:00
zene
8f8470edaf add message body size limits in frontend and backend 2024-09-27 15:17:47 +03:00
Jack Christensen
a95cfbb433
Merge pull request #2129 from s-montigny-desautels/fix/timestamp-json-unmarshal
Fix pgtype.Timestamp json unmarshal
2024-09-24 17:47:28 -05:00
Shean de Montigny-Desautels
7803ec3661
Fix pgtype.Timestamp json unmarshal
Add the missing 'Z' at the end of the timestamp string, so it can be
parsed as timestamp in the RFC3339 format.
2024-09-23 18:12:32 -04:00
Lucas Hild
64ca07e31b Add commit query to tx options 2024-09-23 16:46:58 +02:00
Jack Christensen
fd0c65478e Fix prepared statement already exists on batch prepare failure
When a batch successfully prepared some statements, but then failed to
prepare others, the prepared statements that were successfully prepared
were not properly cleaned up. This could lead to a "prepared statement
already exists" error on subsequent attempts to prepare the same
statement.

https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
2024-09-13 08:03:37 -05:00
Jack Christensen
672c4a3a24 Release v5.7.1 2024-09-10 07:25:07 -05:00
Jack Christensen
f8a5a5c9e3 Update golang.org/x/crypto and golang.org/x/text 2024-09-10 07:17:03 -05:00
Jack Christensen
ab36c2c0dd Upgrade puddle to v2.2.2
This removes the import of nanotime via linkname.
2024-09-10 07:11:44 -05:00
Jack Christensen
ce66b1dae4 Fix data race with TraceLog.Config initialization
https://github.com/jackc/pgx/pull/2120
2024-09-10 07:06:39 -05:00
Jack Christensen
d1205a6dbc Release v5.7.0 2024-09-07 10:23:34 -05:00
Jack Christensen
97d20ccfad
Merge pull request #2115 from ninedraft/sql-err-no-rows
Use sql.ErrNoRows as value for pgx.ErrNoRows
2024-08-26 07:40:46 -05:00
Jack Christensen
e9bd382c51
Merge pull request #2114 from jennifersp/master
add byte length check to uint32
2024-08-26 07:28:47 -05:00
Jack Christensen
603f2337d6
Merge pull request #2113 from mateuszkowalke/master
Add comment for pgtype.Interval struct
2024-08-26 07:28:29 -05:00
merlin
035bbbe0cb
Use sql.ErrNoRows as value for pgx.ErrNoRows 2024-08-26 14:01:37 +03:00
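
What this buys callers, as a sketch (assumes `conn` and `ctx`): a single `errors.Is` check now covers both pgx and database/sql call sites.

```
var n int32
err := conn.QueryRow(ctx, "select 1 where false").Scan(&n)
if errors.Is(err, sql.ErrNoRows) {
	// Also matches pgx.ErrNoRows, since pgx.ErrNoRows is now backed by
	// database/sql's sentinel value.
}
```
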
jennifersp
73bbced270 add byte length check to uint32 2024-08-23 16:17:07 -07:00
mateuszkowalke
4171f554d4 Add additional info for nullable pgtype types
The additional information warns about nullable types being used as
query parameters with Valid set to false.
2024-08-23 15:15:40 +02:00
Jack Christensen
b197994b1f
Merge pull request #2112 from jennifersp/master
support text scanner for binary format for uint32
2024-08-23 07:05:15 -05:00
jennifersp
57fd684068 update struct name 2024-08-22 16:51:42 -07:00
jennifersp
926913ad66 rm bound check 2024-08-21 15:12:36 -07:00
jennifersp
b9f77cb1b3 fix typo 2024-08-21 12:27:23 -07:00
jennifersp
218c15a4eb support text scanner for binary format for uint32 2024-08-21 12:04:54 -07:00
Jack Christensen
4f7e19d67d
Merge pull request #2108 from jackc/ci-tests-go-1.23
CI tests Go 1.23
2024-08-15 18:41:48 -05:00
Jack Christensen
0cbc5db39d CI tests Go 1.23 2024-08-15 18:33:43 -05:00
Jack Christensen
5747f37d9c Fix: Scan and encode types with underlying types of arrays
Rather than special case the reported issue with UUID and [16]byte, this
commit allows the system to find the underlying type of any type that is
an array.

fixes https://github.com/jackc/pgx/issues/2107
2024-08-15 18:20:07 -05:00
Jack Christensen
d6fc8b02b4
Merge pull request #2098 from stringintech/tracelog-time-key
add TraceLogConfig with customizable TimeKey and ensureConfig method for default initialization
2024-08-13 18:39:21 -05:00
Jack Christensen
c457de62c9 Fix doc discrepancies between Tx interface and pgxpool implementation
The error is not wrapped at the moment, but document that it may be.

fixes https://github.com/jackc/pgx/issues/2104
2024-08-07 08:03:41 -05:00
stringintech
216049c62b add TraceLogConfig with customizable TimeKey and ensureConfig method for default initialization 2024-07-26 22:58:14 +03:30
Jack Christensen
a68e14fe5a Explicitly disclaim support for time with time zone 2024-07-23 17:27:05 -05:00
Jack Christensen
ea9610f672
Merge pull request #2084 from EpicStep/multiple-tracing
Implement 'MultiTracer'
2024-07-15 08:22:27 -05:00
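
A sketch of combining tracers, assuming the package landed as `github.com/jackc/pgx/v5/multitracer` with a variadic `New` (here `logTracer` and `metricsTracer` are placeholder `pgx.QueryTracer` implementations):

```
cfg, err := pgx.ParseConfig("postgres://app@localhost:5432/appdb")
if err != nil {
	log.Fatal(err)
}
cfg.Tracer = multitracer.New(logTracer, metricsTracer)
conn, err := pgx.ConnectConfig(ctx, cfg)
```
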
Stepan Rabotkin
7af618e423
feat: add pool tracing to 'MultiTracer' & move it to another package 2024-07-13 17:04:04 +03:00
Stepan Rabotkin
3f270eec7d
feat: add 'MultiTracer' to go doc & cover it 2024-07-13 02:02:22 +03:00
Stepan Rabotkin
8e46d2117c
refac: export 'MultiTracer' fields 2024-07-13 01:40:46 +03:00
Jack Christensen
9530aea47b
Merge pull request #2083 from sodahealth/xml-codec
V1 XMLCodec supports encoding + scanning XML column type
2024-07-12 16:49:08 -05:00
nickcruess-soda
a8aaa37363 fix(test): skip CockroachDB since it doesn't support XML 2024-07-12 09:56:59 -05:00
Jack Christensen
67aa0e5a65
Merge pull request #2085 from nolandseigler/rows-snake-case
RowToStructByName Snake Case Collision
2024-07-12 09:00:27 -05:00
Jack Christensen
96791c88cd
Merge pull request #2082 from heavycrystal/url-parse-err-fix
don't print URL when url.Parse returns an error
2024-07-12 08:52:09 -05:00
nolandseigler
71a8e53574
use normalized equality or strict equality check in rows.go fieldPosByName 2024-07-12 08:50:54 -04:00
Kevin Biju
13e212430d
address review feedback 2024-07-12 18:11:09 +05:30
nolandseigler
b25d092d20
formatting 2024-07-11 23:30:28 -04:00
nolandseigler
7fceb64dee
in rows.go 'fieldPosByName' use a boolean to replace '_' and only execute replacements when there are no db tags present 2024-07-11 23:28:21 -04:00
nolandseigler
7a35585143
example test case that demonstrates snake case collision in db tags caused by rows.go 'fieldPosByName' 2024-07-11 22:39:29 -04:00
Stepan Rabotkin
a787630988
feat: implement 'MultiTracer' 2024-07-11 23:41:17 +03:00
nickcruess-soda
37681a4f48 chore: remove unused JSONCodec code, correct typo 2024-07-11 15:18:20 -05:00
nickcruess-soda
c7b9dc0e00 feat: add pgtype.XMLCodec based on pgtype.JSONCodec 2024-07-11 15:17:55 -05:00
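
A sketch of scanning an xml column, assuming the default encoding/xml-based Marshal/Unmarshal (open `conn` and `ctx` assumed):

```
var note struct {
	XMLName xml.Name `xml:"note"`
	Body    string   `xml:"body"`
}
err := conn.QueryRow(ctx,
	"select '<note><body>hello</body></note>'::xml").Scan(&note)
if err != nil {
	log.Fatal(err)
}
fmt.Println(note.Body) // hello
```
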
Kevin Biju
f007d84675
don't print url when url.Parse returns an error 2024-07-10 22:46:32 +05:30
Jack Christensen
3563a2b048
Merge pull request #2077 from nicois/nicois/pgtype-other-schema-fix
Fix bug relating to reuse of types
2024-07-08 17:29:44 -05:00
Nick Farrell
b770252a3b
Fix bug relating to reuse of types
When `LoadTypes` is being called, it does not include the
namespace-qualified types in its result. While these namespaces are
visible to `LoadTypes` itself, `RegisterTypes` will not recognise this
form of the types, only allowing them to be used if they are on the
schema path, and referred to without their namespace component.
2024-07-07 11:26:19 +10:00
Jack Christensen
c64fa0f0f2 Document that batched queries should not contain multiple statements 2024-07-03 22:49:02 -05:00
Jack Christensen
dced53f796 Better error message when reading past end of batch
https://github.com/jackc/pgx/issues/1801
2024-07-03 22:39:29 -05:00
Jack Christensen
161ce73ec1
Merge pull request #2046 from nicois/nicois/load-types
Faster/easier loading of types
2024-07-01 06:52:10 -05:00
Jack Christensen
fa57a20518 Update go.mod to require go 1.21
pgx now uses slices package.
2024-07-01 06:48:51 -05:00
Jack Christensen
dd71547340
Merge pull request #2066 from yuki0920/use-slices-contains
Fix: use `slices.Contains` according to the TODO comment
2024-07-01 06:47:06 -05:00
Nick Farrell
47977703e1
Load types using a single SQL query
When loading even a single type into pgx's type map, multiple SQL
queries are performed in series. Over a slow link, this is not ideal.
Worse, if multiple types are being registered, this is repeated multiple
times.

This commit adds LoadTypes, which can retrieve type
mapping information for multiple types in a single SQL call, including
recursive fetching of dependent types.
RegisterTypes performs the second stage of this operation.
2024-07-01 15:34:17 +10:00
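
The two-stage flow described above, sketched with placeholder type names (assumes `conn` and `ctx`):

```
// Stage one: a single round trip loads the named types, including
// dependent types fetched recursively.
types, err := conn.LoadTypes(ctx, []string{"myschema.color", "myschema.widget"})
if err != nil {
	log.Fatal(err)
}
// Stage two: registration is purely local.
conn.TypeMap().RegisterTypes(types)
```
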
yuki0920
a764746906
Fix: use slices.Contains according to the TODO comment
I used `slices.Contains` according to the TODO comment.

```
// TODO replace by slices.Contains when experimental package will be merged to stdlib
// https://pkg.go.dev/golang.org/x/exp/slices#Contains
```
2024-06-30 07:58:56 +09:00
Jack Christensen
6b9ff972a4
Merge pull request #2061 from yann-soubeyrand/support-sslrootcert-system
Add support for sslrootcert=system
2024-06-29 06:32:18 -05:00
Yann Soubeyrand
c407c42692 Add support for sslrootcert=system 2024-06-25 11:15:40 +02:00
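
Usage sketch (host and credentials are placeholders): `sslrootcert=system` verifies the server certificate against the operating system's CA store instead of a certificate file.

```
conn, err := pgx.Connect(ctx,
	"host=db.example.com user=app dbname=appdb sslmode=verify-full sslrootcert=system")
```
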
Jack Christensen
9907b874c2 Update pgservicefile
Fixes panic when parsing invalid file.
2024-06-06 07:12:26 -05:00
Jack Christensen
ec557e87d5
Merge pull request #2035 from exekias/fix-interval
Fix interval encoding to allow 0s and avoid extra spaces
2024-05-30 20:13:33 -05:00
Carlos Pérez-Aradros Herce
9f4a264f89 Fix interval encoding to allow 0s and avoid extra spaces
Fix bugs introduced by 01d649b and add some tests
2024-05-30 09:48:53 +02:00
Jack Christensen
572d7fff32 Release v5.6.0 2024-05-25 11:35:25 -05:00
Jack Christensen
b4911f1da7
Merge pull request #2019 from jackc/fix-encode-driver-valuer-on-pointer
Fix encode driver.Valuer on pointer
2024-05-25 11:20:25 -05:00
Jack Christensen
29751194ef Test composites on CockroachDB 2024-05-25 07:49:00 -05:00
Jack Christensen
c1f4cbb5cd Upgrade CockroachDB on CI 2024-05-25 07:48:47 -05:00
Hans-Joachim Kliemeck
24c0a5e8ff remove keepalive and rely on the Go default (since Go 1.13 the default is 15s)
https://www.reddit.com/r/golang/comments/d7v7dn/psa_go_113_introduces_15_sec_server_tcp/
2024-05-21 10:37:13 -05:00
Jack Christensen
9ca9203afb Move typed nil handling to Map.Encode from anynil
The new logic checks for any type of nil at the beginning of Encode and
then either treats it as NULL or calls the driver.Valuer method if
appropriate.

This should preserve the existing nil normalization while restoring the
ability to encode nil driver.Valuer values.
2024-05-18 22:39:28 -05:00
Jack Christensen
79cab4640f Only use anynil inside of pgtype 2024-05-18 21:06:23 -05:00
Jack Christensen
6ea2d248a3 Remove anynil.NormalizeSlice
anynil.Is was already being called in all paths that
anynil.NormalizeSlice was used.
2024-05-18 21:01:34 -05:00
Jack Christensen
c1075bfff0 Remove some special casing for QueryExecModeExec 2024-05-18 20:59:01 -05:00
Jack Christensen
cf6074fe5c Remove unused anynil.Normalize 2024-05-18 20:37:25 -05:00
Jack Christensen
13beb380f5 Fix encode driver.Valuer on nil-able non-pointers
https://github.com/jackc/pgx/issues/1566
https://github.com/jackc/pgx/issues/1860
https://github.com/jackc/pgx/pull/2019#discussion_r1605806751
2024-05-18 17:17:46 -05:00
Jack Christensen
fec45c802b Refactor appendParamsForQueryExecModeExec
Extract logic for finding OID and converting argument to encodable
value. This is in preparation for a future change for better supporting
nil driver.Valuer values.
2024-05-18 17:00:41 -05:00
Jack Christensen
3b7fa4ce87 Use go 1.20 in go.mod
A future commit will use bytes.Clone, which was added in Go 1.20.

Also update README.md to reflect that the minimum supported Go version
is 1.21, but only require Go 1.20 in go.mod to avoid needlessly breaking
older Go versions while they still work.
2024-05-18 16:47:44 -05:00
Mitar
732889728f Add support for custom JSON marshal and unmarshal.
The Codec interface is now implemented by *pgtype.JSONCodec
and *pgtype.JSONBCodec instead of pgtype.JSONCodec and
pgtype.JSONBCodec, respectively. This is technically a breaking
change, but it is extremely unlikely that anyone is depending on this,
and if there is downstream breakage it is trivial to fix.

Fixes #2005.
2024-05-18 08:02:09 -05:00
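
A sketch of the new hook: re-register the json type with explicit Marshal/Unmarshal functions (encoding/json shown; a third-party JSON library slots in the same way; `conn` assumed):

```
m := conn.TypeMap()
m.RegisterType(&pgtype.Type{
	Name: "json",
	OID:  pgtype.JSONOID,
	Codec: &pgtype.JSONCodec{
		Marshal:   json.Marshal,
		Unmarshal: json.Unmarshal,
	},
})
```
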
Mitar
e1b90cf620 Add ltree extension requirement. 2024-05-18 07:56:47 -05:00
Jack Christensen
2a36a7032e Fix encode driver.Valuer on pointer
pgx v5 introduced nil normalization for typed nils. This means that
[]byte(nil) is normalized to nil at the edge of the encoding system.
This simplified encoding logic as nil could be encoded as NULL and type
specific handling was unneeded.

However, database/sql compatibility requires Value to be called on a
nil pointer that implements driver.Valuer. This was broken by
normalizing to nil.

This commit changes the normalization logic to not normalize pointers
that directly implement driver.Valuer to nil. It still normalizes
pointers that implement driver.Valuer through implicit dereference.

e.g.

type T struct{}

func (t *T) Value() (driver.Value, error) {
  return nil, nil
}

type S struct{}

func (s S) Value() (driver.Value, error) {
  return nil, nil
}

(*T)(nil) will not be normalized to nil but (*S)(nil) will be.

https://github.com/jackc/pgx/issues/1566
2024-05-18 07:41:10 -05:00
Jack Christensen
ded01c0cd9 Fix TestParseConfigEnvLibpq unsetting envars
This would cause tests to fail if PG* variables were used for the
default connection config for other tests.

Previously broken by 0080acf318d162a1128928bc32eadf45cef61fd2.
2024-05-17 09:19:36 -05:00
ngavinsir
532bf8f583 adjust test 2024-05-14 20:28:02 -05:00
ngavinsir
169067a364 remove ctx from release tracer 2024-05-14 20:28:02 -05:00
ngavinsir
659525c961 trace release 2024-05-14 20:28:02 -05:00
ngavinsir
4dd1810d8b persist ctx in pgxpool conn 2024-05-14 20:28:02 -05:00
ngavinsir
25914e21f3 add release tracer 2024-05-14 20:28:02 -05:00
ngavinsir
19fcb54564 add pool to trace acquire 2024-05-14 20:28:02 -05:00
ngavinsir
a39632db43 feat: pgx pool acquire tracer 2024-05-14 20:28:02 -05:00
Oleksandr Redko
c05cce7d41 Fix test asserts: reverse expected-actual 2024-05-14 20:07:10 -05:00
Oleksandr Redko
0080acf318 Simplify config tests by using T.Setenv, T.TempDir 2024-05-14 20:06:18 -05:00
Mitar
c81bba8690 Use pgtype.PreallocBytes in LargeObject's Read.
Fixes #1876.
2024-05-14 07:03:24 -05:00
Pavlo Golub
523411a3fb make QueuedQuery.Fn property public, closes #1878
This commit fixes an oversight in #1886, where the SQL and Arguments
properties were exposed.
2024-05-12 09:03:47 -05:00
Jack Christensen
a966716860 Replace DSN with keyword/value in comments and documentation
The term DSN is not used in the PostgreSQL documentation. I'm not sure
why it was originally used. Use the correct PostgreSQL terminology.
2024-05-11 14:33:35 -05:00
Jack Christensen
cf50c60869 Fix error check on CI 2024-05-11 14:33:13 -05:00
Jack Christensen
8db971660e Failed connection attempts include all errors
A single Connect("connstring") may actually make multiple connection
requests due to TLS or HA configuration. Previously, when all attempts
failed only the last error was returned. This could be confusing.
Now details of all failed attempts are included.

For example, the following connection string:

host=localhost,127.0.0.1,foo.invalid port=1,2,3

Will now return an error like the following:

failed to connect to `user=postgres database=pgx_test`:
	lookup foo.invalid: no such host
	[::1]:1 (localhost): dial error: dial tcp [::1]:1: connect: connection refused
	127.0.0.1:1 (localhost): dial error: dial tcp 127.0.0.1:1: connect: connection refused
	127.0.0.1:2 (127.0.0.1): dial error: dial tcp 127.0.0.1:2: connect: connection refused

https://github.com/jackc/pgx/issues/1929
2024-05-11 14:25:03 -05:00
Jack Christensen
48cdd7bab0 Allow scanning time without time zone into string
https://github.com/jackc/pgx/issues/2002
2024-05-10 10:52:41 -05:00
Jack Christensen
579a320c1c pgconn.SafeToRetry checks for wrapped errors
Use errors.As instead of type assertion.

Port 4e2e7a040579c1999c0766642d836eb28c6e2018 to v5

Credit to tjasko
2024-05-09 17:59:16 -05:00
Carlos Pérez-Aradros Herce
01d649b2bf Do not encode interval microseconds when they are 0
This makes the encoding match what Postgres does 2024-05-09 17:29:13 -05:00
2024-05-09 17:29:13 -05:00
Jack Christensen
48ae1f4b2c Fix ResultReader.Read() to handle nil values
The ResultReader.Read() method was erroneously converting nil values
to []byte{}.

https://github.com/jackc/pgx/issues/1987
2024-05-09 17:13:26 -05:00
WGH
e4f72071f8 Document that generic helpers call rows.Close()
The existing generic helpers always defer rows.Close(), and examples of
their usage omit an external defer rows.Close() call.

For clarity, state that explicitly, because it is another reason to
switch to the generic helpers from a manually written rows.Next() loop.
2024-05-09 15:54:48 -05:00
Jack Christensen
6f0deff015 Add custom data to pgconn.PgConn
https://github.com/jackc/pgx/issues/1896
2024-05-09 15:39:28 -05:00
Jack Christensen
8649231bb3 Add ScanLocation to pgtype.TimestampCodec
If ScanLocation is set, the timestamps will be assumed to be in the
given location when scanning from the database.

The Codec interface is now implemented by *pgtype.TimestampCodec instead
of pgtype.TimestampCodec. This is technically a breaking change, but it
is extremely unlikely that anyone is depending on this, and if there is
downstream breakage it is trivial to fix.

https://github.com/jackc/pgx/issues/1195
https://github.com/jackc/pgx/issues/1945
2024-05-08 08:35:05 -05:00
Jack Christensen
33360ab479 Add ScanLocation to pgtype.TimestamptzCodec
If ScanLocation is set, it will be used to convert the time to the given
location when scanning from the database.

The Codec interface is now implemented by *pgtype.TimestamptzCodec
instead of pgtype.TimestamptzCodec. This is technically a breaking
change, but it is extremely unlikely that anyone is depending on this,
and if there is downstream breakage it is trivial to fix.

https://github.com/jackc/pgx/issues/1195
https://github.com/jackc/pgx/issues/1945
2024-05-08 08:35:05 -05:00
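
A sketch covering both ScanLocation commits above (`conn` assumed; time.UTC is an arbitrary example location):

```
m := conn.TypeMap()
// timestamp values are scanned as if they were stored in UTC.
m.RegisterType(&pgtype.Type{
	Name:  "timestamp",
	OID:   pgtype.TimestampOID,
	Codec: &pgtype.TimestampCodec{ScanLocation: time.UTC},
})
// timestamptz values are converted to UTC when scanned.
m.RegisterType(&pgtype.Type{
	Name:  "timestamptz",
	OID:   pgtype.TimestamptzOID,
	Codec: &pgtype.TimestamptzCodec{ScanLocation: time.UTC},
})
```
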
Jack Christensen
c31619d08b Improve docs for customizing context cancellation 2024-05-08 08:08:21 -05:00
Jack Christensen
ec9bb2ace7 Improve flickering test on CI 2024-05-08 07:54:17 -05:00
Jack Christensen
93a579754b Add CancelRequestContextWatcherHandler
This allows a context to cancel a query by sending a cancel request to
the server before falling back to setting a deadline.
2024-05-08 07:41:02 -05:00
Jack Christensen
42c9e9070a Allow customizing context canceled behavior for pgconn
This feature made the ctxwatch package public.
2024-05-08 07:41:02 -05:00
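
A hedged sketch of the two commits above taken together, assuming the handler and the pgconn/ctxwatch package landed with these names (the delays are arbitrary examples):

```
cfg, err := pgconn.ParseConfig("host=localhost user=postgres")
if err != nil {
	log.Fatal(err)
}
cfg.BuildContextWatcherHandler = func(pc *pgconn.PgConn) ctxwatch.Handler {
	return &pgconn.CancelRequestContextWatcherHandler{
		Conn:               pc,
		CancelRequestDelay: 100 * time.Millisecond, // first: send a cancel request
		DeadlineDelay:      time.Second,            // then: fall back to a deadline
	}
}
```
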
Oleksandr Redko
60a01d044a Fix typos in doc comments 2024-04-17 12:00:02 -05:00
Zach Olstein
8f69e45a53 fixup! Cache reflection analysis in RowToStructBy... 2024-04-16 13:08:16 -05:00
Zach Olstein
ec98406207 Cache reflection analysis in RowToStructBy...
Modify the RowToStructByPos/Name functions to store the computed mapping
of columns to struct field locations in a cache to reuse between calls.
Because this computation can be expensive and the same few results will
frequently be reused, caching these results provides a significant
speedup.

For positional mappings, we can key the cache by just the struct-type.
However, for named mappings, the key must include a representation of
the columns, in order, since different columns produce different
mappings.
2024-04-16 13:08:16 -05:00
Jack Christensen
8db0f280fb Add benchmarks for RowToStructBy(Pos|Name) 2024-04-16 12:59:40 -05:00
Felix Röhrich
fc416d237a make parsing stricter and add corresponding test 2024-04-16 12:18:06 -05:00
Jack Christensen
a3d9120636 Add SeverityUnlocalized field to PgError / Notice
https://github.com/jackc/pgx/issues/1971
2024-04-07 08:58:10 -05:00
Carlos Pérez-Aradros Herce
78b22c3d2f fix tests 2024-03-20 18:21:11 -05:00
Carlos Pérez-Aradros Herce
221ad1b84c Add support for macaddr8 type
Postgres also has a `macaddr8` type; this PR adds support for it, using
the same codec as `macaddr`
2024-03-20 18:21:11 -05:00
Tomas Zahradnicek
b6e5548341 StrictNamedArgs 2024-03-16 10:59:31 -05:00
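
A sketch of the new argument type (`conn` and `ctx` assumed): like `pgx.NamedArgs`, but it returns an error when the SQL references a name that is not provided.

```
rows, err := conn.Query(ctx,
	"select id, name from widgets where id = @id",
	pgx.StrictNamedArgs{"id": int32(42)},
)
```
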
Jack Christensen
1b6227af11 Remove verbose flag from go test command on CI
It is more often that interesting information is buried by the verbose
output than the verbose output is useful. It can be reenabled later if
necessary.
2024-03-16 09:52:50 -05:00
Jack Christensen
c1fce377ee Test Go 1.22 and drop Go 1.20 from testing on CI 2024-03-16 09:44:23 -05:00
Jack Christensen
7fd6f2a4f5 Disable parallel testing on Github Actions CI
Tests were failing with:
Error: Process completed with exit code 143.

This appears to mean that Github Actions killed the runner.

See https://github.com/jackc/pgx/actions/runs/8216337993/job/22470808811
for an example.

It appears Github Actions kills runners based on resource usage. Running
tests one at a time reduces the resource usage and avoids the problem.

Or at least that's what I presume is happening. It sure is fun debugging
issues on cloud systems where you have limited visibility... :(

fixes https://github.com/jackc/pgx/issues/1934
2024-03-16 09:41:51 -05:00
Jack Christensen
78a0a2bf41 Fix spelling in changelog 2024-03-09 12:16:20 -06:00
Jack Christensen
a17f064492 Update changelog 2024-03-09 12:12:41 -06:00
Jack Christensen
49b6aad319 Use spaces instead of parentheses for SQL sanitization
This still solves the problem of negative numbers creating a line
comment, but this avoids breaking edge cases such as `set foo to $1`
where the substitution is taking place in a location where an arbitrary
expression is not allowed.

https://github.com/jackc/pgx/issues/1928
2024-03-09 12:09:42 -06:00
Felix
0cc4c14e62 Add test to validate CollectRows for empty Rows
https://github.com/jackc/pgx/issues/1924
https://github.com/jackc/pgx/issues/1925
2024-03-06 22:05:32 -06:00
Jack Christensen
da6f2c98f2 Update changelog 2024-03-04 09:12:06 -06:00
Jack Christensen
c543134753 SQL sanitizer wraps arguments in parentheses
pgx v5 was not vulnerable to CVE-2024-27289 due to how the sanitizer was
being called. But the sanitizer itself still had the underlying issue.
This commit ports the fix from pgx v4 to v5 to ensure that the issue
does not emerge if pgx uses the sanitizer differently in the future.
2024-03-04 09:09:42 -06:00
Jack Christensen
20344dfae8 Check for overflow on uint16 sizes in pgproto3 2024-03-04 09:09:29 -06:00
Jack Christensen
adbb38f298 Do not allow protocol messages larger than ~1GB
The PostgreSQL server will reject messages greater than ~1 GB anyway.
However, worse than that is that a message that is larger than 4 GB
could wrap the 32-bit integer message size and be interpreted by the
server as multiple messages. This could allow a malicious client to
inject arbitrary protocol messages.

https://github.com/jackc/pgx/security/advisories/GHSA-mrww-27vc-gghv
2024-03-04 09:09:29 -06:00
Felix
c1b0a01ca7 Fix behavior of CollectRows to return empty slice if Rows are empty
https://github.com/jackc/pgx/issues/1924
2024-03-03 07:52:18 -06:00
Jack Christensen
88dfc22ae4 Fix simple protocol encoding of json.RawMessage
The underlying type of json.RawMessage is a []byte so to avoid it being
considered binary data we need to handle it specifically. This is done
by registerDefaultPgTypeVariants. In addition, handle json.RawMessage in
the JSONCodec PlanEncode to avoid it being mutated by json.Marshal.

https://github.com/jackc/pgx/issues/1763
2024-03-02 15:12:20 -06:00
Jack Christensen
2e84dccaf5 *Pipeline.getResults should close pipeline on error
Otherwise, it might be possible to panic when closing the pipeline if it
tries to read a connection that should be closed but still has a fatal
error on the wire.

https://github.com/jackc/pgx/issues/1920
2024-02-29 18:44:01 -06:00
David Kurman
d149d3fe5c Fix panic in TryFindUnderlyingTypeScanPlan
Check if CanConvert before calling reflect.Value.Convert
2024-02-26 17:51:56 -06:00
Jack Christensen
046f497efb deallocateInvalidatedCachedStatements now runs in transactions
https://github.com/jackc/pgx/issues/1847
2024-02-24 10:16:18 -06:00
Jack Christensen
8896bd6977 Handle invalid sslkey file
https://github.com/jackc/pgx/issues/1915
2024-02-24 09:24:26 -06:00
Jack Christensen
85f15c4b3c Fix scan float4 into sql.Scanner
https://github.com/jackc/pgx/issues/1911
2024-02-23 18:18:03 -06:00
Jack Christensen
654dcab93e Fix: pgtype.Bits makes copy of data from read buffer
It was taking a reference. This would cause the data to be corrupted by
future reads.

fixes #1909
2024-02-23 17:40:11 -06:00
Tom Payne
5c63f646f8 Add link to github.com/twpayne/pgx-geos 2024-02-04 22:04:03 -06:00
Jack Christensen
6f8f6ede6c Update changelog for v5.5.3 2024-02-03 12:52:29 -06:00
Jack Christensen
576b6c88f6 Bump actions/setup-go version
This gets rid of some deprecation warnings on Github Actions.
2024-02-03 12:50:20 -06:00
Jack Christensen
7caa448ac8 Skip test on CockroachDB 2024-02-03 12:41:59 -06:00
Jack Christensen
832b4f9771 Fix: prepared statement already exists
When a conn is going to execute a query, the first thing it does is to
deallocate any invalidated prepared statements from the statement cache.
However, the statements were removed from the cache regardless of
whether the deallocation succeeded. This would cause subsequent calls of
the same SQL to fail with "prepared statement already exists" error.

This problem is easy to trigger by running a query with a context that
is already canceled.

This commit changes the deallocate invalidated cached statements logic
so that the statements are only removed from the cache if the
deallocation was successful on the server.

https://github.com/jackc/pgx/issues/1847
2024-02-03 12:33:17 -06:00
Jack Christensen
fd4411453f Improve Conn.LoadType documentation 2024-02-03 10:29:10 -06:00
Jack Christensen
34da2fed95 Improve CopyFrom auto-conversion of text-ish values
CopyFrom requires that all values are encoded in the binary format. It
already tried to parse strings to values that can then be encoded into
the binary format. But it didn't handle types that can be encoded as
text and then parsed and converted to binary. It now does.
2024-02-03 09:49:56 -06:00
Jack Christensen
7b5fcac465 Add timetz and []timetz OID constants
https://github.com/jackc/pgx/issues/1883
2024-01-27 18:55:59 -06:00
Jack Christensen
0819a17da8 Remove openssl from TLS test setup
TLS setup and tests were rather finicky. It seems that openssl 3
encrypts certificates differently than older openssl and it does it in
a way Go and/or pgx ssl handling code can't handle. It appears that
this related to the use of a deprecated client certificate encryption
system.

This caused CI to be stuck on Ubuntu 20.04 and recently caused the
contributing guide to fail to work on MacOS.

Remove openssl from the test setup and replace it with a Go program
that generates the certificates.
2024-01-27 09:04:19 -06:00
Florent Viel
bf1c1d7848 create ltree extension in pg setup for tests 2024-01-26 09:06:13 -06:00
Florent Viel
0fa533386c add ltree pgtype support 2024-01-26 09:06:13 -06:00
Pavlo Golub
c90f82a4e3 make properties of QueuedQuery and Batch public, closes #1878 2024-01-25 18:03:59 -06:00
Edoardo Spadolini
a57bb8caea Add AppendRows helper 2024-01-23 17:14:24 -06:00
Kirill Malikov
517c654e2c feat: fast encodeUUID 2024-01-20 20:50:01 -06:00
Mitar
a4ca0917da Support large large objects.
Fixes #1865.
2024-01-15 08:50:55 -06:00
Mitar
0c35c9e630 Revert "Document max read and write sizes for large objects"
This reverts commit b99e2bb7e0818428092e955cb0ee9cff45504bfd.
2024-01-15 08:50:55 -06:00
Jack Christensen
b7de418d46 Release v5.5.2 2024-01-13 11:08:35 -06:00
Jack Christensen
b99e2bb7e0 Document max read and write sizes for large objects
https://github.com/jackc/pgx/issues/1865
2024-01-13 10:43:35 -06:00
Jack Christensen
52f2151422 Allow NamedArgs to start with underscore
fixes #1869
2024-01-13 10:20:25 -06:00
Endre Kovács
dfb6489612 fix typo in doc.go 2024-01-13 09:40:00 -06:00
Chris Frank
9346d48035 fix OpenDBFromPool example 2024-01-13 09:39:16 -06:00
jeremy.spriet
1fdd17041a feat(pgproto3): expose MaxExpectedBodyLen and ActualBodyLen in ExceededMaxBodyLenErr struct 2024-01-12 18:21:07 -06:00
Jack Christensen
f654d61d79 Make note about possible parse config error message redaction change 2024-01-12 17:56:13 -06:00
Jack Christensen
5d26bbefd8 Make pgconn.ConnectError and pgconn.ParseConfigError public
fixes #1773
2024-01-12 17:52:25 -06:00
vahid-sohrabloo
44768b5a01 fix a typo in config_test.go
2024-01-12 17:36:43 -06:00
Jack Christensen
6f2ce92356 Upgrade golang.org/x/crypto to v0.17.0
pgx is unaffected by CVE-2023-48795 because it does not use SSH.
However, dependabot and other vulnerability scanners may complain so
bump the dependency anyway.
2023-12-29 18:14:09 -06:00
Tikhon Fedulov
4367ee0598 Update TestRowToStructByName with snake case support 2023-12-25 09:47:10 -06:00
Tikhon Fedulov
d2c9ebc2ef Use local variables in fieldPosByName and fix errors 2023-12-25 09:47:10 -06:00
Tikhon Fedulov
0c7acf9481 Add snake_case support to RowToStructByName 2023-12-25 09:47:10 -06:00
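
A sketch of what the snake_case support allows (`conn` and `ctx` assumed): the full_name column now matches the FullName field without a db tag.

```
type User struct {
	ID       int32
	FullName string
}

rows, err := conn.Query(ctx, "select 1::int4 as id, 'Ada'::text as full_name")
if err != nil {
	log.Fatal(err)
}
users, err := pgx.CollectRows(rows, pgx.RowToStructByName[User])
```
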
Jack Christensen
cbc5a7055f Fix: close conn on read failure in pipeline
Suggested by @jameshartig in https://github.com/jackc/pgx/issues/1847
2023-12-23 12:11:23 -06:00
James Hartig
4c14caae07 update description cache after exec prepare 2023-12-23 12:08:02 -06:00
James Hartig
22fe50149b pgconn: check if pipeline is closed in Sync/GetResults
Otherwise there will be a nil pointer exception accessing the conn
2023-12-23 12:04:21 -06:00
Ryan Fowler
dfd198003a Fix panic in Pipeline when PgConn is busy or closed 2023-12-23 10:30:59 -06:00
jeremy.spriet
603c8c1e90 feat(pgproto3/backend): add a SetMaxBodyLen to limit the max body length for the receive 2023-12-23 10:25:35 -06:00
Samuel Stauffer
9ab9e3c40b Unwrap errors in normalizeTimeoutError 2023-12-16 11:15:35 -06:00
Samuel Stauffer
2daeb8dc5f pgconn: normalize starTLS connection error
Normalize the error that is returned by startTLS in pgconn.connect. This
makes it possible to determine if the error was a context error.
2023-12-16 11:15:35 -06:00
Jack Christensen
df3c5f4df8 Use "Pg" instead of "PG" in new PgError related identifiers
Arguably, PGError might have been better. But since the precedent is
long since established it is better to be consistent.
2023-12-15 18:33:51 -06:00
James Hartig
b1631e8e35 pgconn: add OnPGError to Config for error handling
OnPGError is called on every error response received from Postgres and can
be used to close connections on specific errors. Defaults to closing on
FATAL-severity errors.

Fixes #1803
2023-12-15 18:29:32 -06:00
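
A sketch using the post-rename identifier OnPgError (see the "Pg" naming commit above); the severity check mirrors the stated default:

```
cfg, err := pgconn.ParseConfig("host=localhost user=postgres")
if err != nil {
	log.Fatal(err)
}
cfg.OnPgError = func(pc *pgconn.PgConn, pgErr *pgconn.PgError) bool {
	log.Printf("server error: %s", pgErr.Message)
	// Returning false closes the connection; the default behavior closes
	// only on FATAL-severity errors, which this reproduces.
	return pgErr.Severity != "FATAL"
}
```
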
Jack Christensen
ba05097642 Release v5.5.1 2023-12-09 12:59:44 -06:00
Evan Jones
384fe7775c Batch.Queue: document always uses the conn's DefaultQueryExecMode
The only way to change the query mode used by Batch.Queue and
SendBatch is to use a connection with a different
DefaultQueryExecMode. Add this to the function documentation.

Conn.SendBatch: Move where mode is defined to make this clearer in
the code. I spent time looking for the option that does not exist.
2023-12-09 11:47:56 -06:00
Eshton Robateau
20bf953a17 pull out changes into new public function 2023-12-09 11:20:14 -06:00
Eshton Robateau
12582a0fd4 bitsize largest option is 64 2023-12-09 11:20:14 -06:00
Eshton Robateau
905f252667 uncomment tests 2023-12-09 11:20:14 -06:00
Eshton Robateau
9927e14bbf remove dead line 2023-12-09 11:20:14 -06:00
Eshton Robateau
95b2f85e60 support scientific notation big floats 2023-12-09 11:20:14 -06:00
Jack Christensen
913e4c8487 Update changelog 2023-12-02 09:36:03 -06:00
Jack Christensen
31321c2017 Add race detector to bug report template 2023-12-02 09:27:57 -06:00
maksymnevajdev
319c3172f2 fix panic in prepared sql 2023-12-01 18:34:41 -06:00
Simon Paredes
4678e69599 fix error message to print the unexpected rune 2023-12-01 18:23:23 -06:00
Simon Paredes
89d699c2e8 wrap errors instead of just formatting them 2023-12-01 18:23:23 -06:00
Jacopo
7ebced92b5 Fix issue with order of json encoding #1805 2023-11-24 19:01:48 -06:00
Sam Whited
94e56e61ba Fix usage of logger in stdlib docs
The documentation previously showed the old way of logging and not the
newer tracer adapter. This patch updates the example to build correctly
with pgx/v5.

Signed-off-by: Sam Whited <sam@samwhited.com>
2023-11-22 08:15:05 -06:00
Jack Christensen
9103457384 Improve docs 2023-11-18 07:44:24 -06:00
Jack Christensen
9782306287 Only remove statement from map if deallocate succeeds
https://github.com/jackc/pgx/pull/1795
2023-11-18 07:44:24 -06:00
Jack Christensen
7d5a3969d0 Improve docs and tests 2023-11-18 07:44:24 -06:00
Jack Christensen
e5015e2fac pgx.Conn.Deallocate uses PgConn.Deallocate
This uses the PostgreSQL protocol to deallocate a prepared statement
instead of a SQL statement. This allows it to work even in an aborted
transaction.
2023-11-18 07:44:24 -06:00
Jack Christensen
4dbd57a7ed Add PgConn.Deallocate method
This method uses the PostgreSQL protocol Close method to deallocate a
prepared statement. This means that it can succeed in an aborted
transaction.
2023-11-18 07:44:24 -06:00
Jack Christensen
0570b0e196 Better document PgConn.Prepare implementation 2023-11-18 07:44:24 -06:00
Jack Christensen
df5d00eb60 Remove PostgreSQL 11 from supported versions 2023-11-11 10:09:47 -06:00
robford
d38dd85756 Allowed nxtf to signal end of data by returning nil, nil
Added some tests
Improved documentation
2023-11-11 10:06:58 -06:00
robford
9b6d3809d6 added tests 2023-11-11 10:06:58 -06:00
robford
b4d72d4fce copyFromFunc 2023-11-11 10:06:58 -06:00
robford
ccdd85a5eb added CopyFromCh 2023-11-11 10:06:58 -06:00
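
A sketch of the function-based CopyFrom source this series adds (`conn` and `ctx` assumed; table and columns are placeholders). The nxtf callback signals end of data by returning nil, nil:

```
data := [][]any{{int32(1), "a"}, {int32(2), "b"}}
i := 0
src := pgx.CopyFromFunc(func() ([]any, error) {
	if i == len(data) {
		return nil, nil // end of data
	}
	row := data[i]
	i++
	return row, nil
})
n, err := conn.CopyFrom(ctx, pgx.Identifier{"widgets"}, []string{"id", "name"}, src)
```
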
Jack Christensen
96f5f9cd95 Release v5.5.0 2023-11-04 10:27:32 -05:00
Kirill Mironov
d3fb6e00da implement json.Marshaler and json.Unmarshaler for Float4, Float8 2023-11-04 10:25:31 -05:00
Jack Christensen
cf6ef75f91 stdlib: Use Ping instead of CheckConn in ResetSession
CheckConn is deprecated. It doesn't detect all network outages. It
causes a 1ms delay while it tries to read the connection. Ping incurs a
round trip but that means it is a much stronger guarantee that the
connection is usable. In addition, if the application and the database
are on the same network it will actually be faster as round trip times
are typically a few hundred microseconds.
2023-10-26 20:41:44 -05:00
Jack Christensen
7a4bb7edb5
Add link to pgx presentation to README.md 2023-10-20 18:49:41 -05:00
Ivan Posazhennikov
6f7400f428 fix typo in the comment in the pgconn.go 2023-10-14 18:02:35 -05:00
Anton Levakin
304697de36 CancelRequest: Wait for the cancel request to be acknowledged by the server 2023-10-14 17:48:16 -05:00
Anton Levakin
5d0f904831 update TestConnContextCanceledCancelsRunningQueryOnServer
Check cancellation of the request for pgbouncer
2023-10-14 17:48:16 -05:00
Anton Levakin
6ca3d8ed4e Revert "CancelRequest: don't try to read the reply"
This reverts commit c861bce438ee5b96cc2dcc78718731dce6949060.
2023-10-14 17:48:16 -05:00
Jack Christensen
81ddcfdefb Fix spurious deadline exceeded error
stdlib_test.TestConnConcurrency had been flickering on CI with deadline /
timeout errors. This was extremely confusing because the test deadline
was set for 2 minutes and the errors would occur much quicker.

The problem only manifested in an extremely specific and timing
sensitive situation.

1. The watchdog timer for deadlocked writes starts the goroutine to
   start the background reader
2. The background reader is stopped
3. The next operation is a read without a preceding write (AFAIK only
   CheckConn does this)
4. The deadline is set to interrupt the read
5. The goroutine from 1 actually starts the background reader
6. The background reader gets an error reading the connection with the
   deadline
7. The deadline is cleared
8. The next read on the connection will get the timeout error
2023-10-14 11:38:33 -05:00
Jack Christensen
45f807fdb4 Special case the underlying type of []byte
Underlying types were already tried, but []byte is not a normal
underlying type; it is a slice. Since it can be treated as a scalar
instead of an array / slice, we need to special case it.

https://github.com/jackc/pgx/issues/1763
2023-10-12 20:52:49 -05:00
Jack Christensen
8a09979417 Skip test on CockroachDB 2023-10-10 22:07:06 -05:00
Jack Christensen
7a2b93323c Prevent prematurely closing statements in database/sql
This error was introduced by 0f0d23659950bbf7a1677e50aac09b1e29ad7c60.
If the same statement was prepared multiple times then whenever Close
was called on one of the statements the underlying prepared statement
would be closed even if other statements were still using it.

https://github.com/jackc/pgx/issues/1754#issuecomment-1752004634
2023-10-10 21:56:26 -05:00
Nicola Murino
1484fec57f CI: add PostgreSQL 16 2023-10-10 20:54:54 -05:00
Nicola Murino
3957163808 Update supported Go versions and add 1.21 to CI 2023-10-10 20:54:54 -05:00
Jack Christensen
7fc908a5f2 Do not call t.Fatal in goroutine
require.Equal internally calls t.Fatal, which is not safe to call in a
goroutine.
2023-10-07 10:37:24 -05:00
Jack Christensen
0f0d236599 database/sql prepared statement names are deterministically generated
stdlib now uses the functionality introduced in
bbe2653bc51361e5d7607729356344ef979a9f5a for prepared statements. This
means that the prepared statement name is stable for a given query even
across connections and program executions.

It also makes tracing easier.

See https://github.com/jackc/pgx/issues/1754
2023-10-07 10:16:25 -05:00
Ville Skyttä
c6c50110db Spelling and grammar fixes 2023-10-07 09:26:23 -05:00
Jack Christensen
91530db629 Fix typo in string.Cut refactor 2023-10-07 09:20:28 -05:00
Ville Skyttä
24ed0e4257 Make use of strings.Cut 2023-10-04 20:41:55 +03:00
Jack Christensen
163eb68866 Normalize timeout error when receiving pipeline results
https://github.com/jackc/pgx/issues/1748#issuecomment-1740437138
2023-09-30 08:50:40 -05:00
Jack Christensen
a61517a83b SendBatch should pass ctx to StartPipeline
https://github.com/jackc/pgx/issues/1748
2023-09-28 20:00:02 -05:00
Vincent Le Goff
d93f31b8fa docs: GetPoolConnector 2023-09-25 08:51:12 -05:00
Jack Christensen
cf72a00f52 Skip test of unsupported operation on CockroachDB 2023-09-23 10:49:11 -05:00
Jack Christensen
c08cc72306 Improve QueryExecModeCacheDescribe and clarify documentation
QueryExecModeCacheDescribe actually is safe even when the schema or
search_path is modified. It may return an error on the first execution
but it should never silently encode or decode a value incorrectly. Add a
test to demonstrate and ensure this behavior.

Update documentation of QueryExecModeCacheDescribe to remove warning of
undetected result decoding errors.

Update documentation of QueryExecModeCacheStatement and
QueryExecModeCacheDescribe to indicate that the first execution of an
invalidated statement may fail.
2023-09-23 10:35:42 -05:00
Jack Christensen
7de53a958b stmtcache: Use deterministic, stable statement names
Statement names are now a function of the SQL. This may make database
diagnostics, monitoring, and profiling easier.
2023-09-23 09:55:05 -05:00
Jack Christensen
bbe2653bc5 Prepare chooses statement name based on sql if name == sql
This makes it easier to explicitly manage prepared statements.

refs #1716
2023-09-23 08:40:06 -05:00
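
A sketch of the convention this enables (`conn` and `ctx` assumed):

```
sql := "select id, name from widgets where id = $1"
// Passing the SQL as its own name asks pgx to derive a deterministic,
// stable statement name from the query text.
sd, err := conn.Prepare(ctx, sql, sql)
```
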
Mochammad Hanif R
4e7aa59d64 Fix typos in docs 2023-09-23 07:20:36 -05:00
Lev Zakharov
b301530a5f add doc for OpenDBFromPool 2023-09-09 08:13:56 -05:00
Lev Zakharov
f42824cab3 update docs 2023-09-09 08:13:56 -05:00
Lev Zakharov
18856482c4 remove before/after acquire hooks 2023-09-09 08:13:56 -05:00
Lev Zakharov
639691c0ab add test for stdlib.OpenDBFromPool 2023-09-09 08:13:56 -05:00
Lev Zakharov
3e716c4b06 add example to the doc 2023-09-09 08:13:56 -05:00
Lev Zakharov
51ade172e5 refactor to use the same connection implementation 2023-09-09 08:13:56 -05:00
Lev Zakharov
3d4540aa1b add *sql.DB construction from *pgxpool.Pool 2023-09-09 08:13:56 -05:00
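
A sketch of the new constructor (connection string is a placeholder; imports elided); a later doc commit earlier in this log makes explicit that the returned *sql.DB must be closed:

```
pool, err := pgxpool.New(ctx, "postgres://app@localhost:5432/appdb")
if err != nil {
	log.Fatal(err)
}
defer pool.Close()

db := stdlib.OpenDBFromPool(pool)
defer db.Close() // close the *sql.DB in addition to the pool

var greeting string
err = db.QueryRowContext(ctx, "select 'hello'").Scan(&greeting)
```
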
Julien GOTTELAND
389931396e No data result on error 2023-08-19 18:31:41 -05:00
Julien GOTTELAND
9ee7d29cf9 Add CollectExactlyOneRow function 2023-08-19 18:31:41 -05:00
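A hedged sketch of the new helper (table and struct assumed):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

type User struct {
	ID   int64
	Name string
}

// getUser fails if the query returns zero rows or more than one row,
// rather than silently taking the first. The Query error can be ignored
// here because it is surfaced again by the collect helper.
func getUser(ctx context.Context, conn *pgx.Conn, id int64) (User, error) {
	rows, _ := conn.Query(ctx, `select id, name from users where id = $1`, id)
	return pgx.CollectExactlyOneRow(rows, pgx.RowToStructByName[User])
}
```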
Craig Ringer
a7375cc503 docs: Emphasise need to call rows.Err() after rows.Next() returns false
The Rows interface in pgx, like its ancestor in database/sql, is easy to
accidentally misuse in a way that can cause apps to misinterpret
database or connection errors as successful queries with empty or
truncated result-sets.

Update the docs to emphasise the need to call rows.Err() after
rows.Next() returns false, and direct users of the interface to the v5
API helpers that make writing correct code easier.

The docs on Conn.Query() already call this out, so only a small change
is needed to warn users and point them at the details on Query().

Per details in #1707
2023-08-10 17:19:15 -05:00
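The pattern the docs emphasize, sketched as a complete example:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// listNames checks rows.Err() after the loop: rows.Next() returning false
// can mean "done" or "error", so skipping this check risks treating a
// broken connection as an empty or truncated result set.
func listNames(ctx context.Context, conn *pgx.Conn) ([]string, error) {
	rows, _ := conn.Query(ctx, `select name from users`) // error surfaces via rows.Err()
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		names = append(names, name)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return names, nil
}
```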
Jack Christensen
d43bd349c1 Add batch insert benchmarks 2023-08-08 18:44:31 -05:00
Jack Christensen
5c6cf62b53 Fix off by one error in benchmark 2023-08-08 18:38:34 -05:00
Jack Christensen
d17440d5c7 Add missed changelog entry and fix typo 2023-08-05 08:36:48 -05:00
Jack Christensen
4c60839c48 Release v5.4.3 2023-08-05 08:24:37 -05:00
Jack Christensen
e9087eacb8 Fix data race when pgproto3 trace is enabled during CopyFrom
https://github.com/jackc/pgx/issues/1703
2023-08-05 07:30:59 -05:00
Jack Christensen
d626dfe94e TestConnConcurrency has been failing on CI
This probably won't fix it, but at the very least we should not be
running assertions in a goroutine.
2023-07-28 18:13:31 -05:00
Jack Christensen
1a9b2a53a5 Fix staticcheck issues 2023-07-28 18:04:31 -05:00
Alexey Palazhchenko
8fb309c631 Use Go 1.20's link syntax for ParseConfig 2023-07-28 17:51:42 -05:00
horpto
f4533dc906 optimize parseNumericString 2023-07-25 19:25:23 -05:00
Jack Christensen
4091eedf03 Check out code before setting up Go
This allows for caching the Go dependencies.
2023-07-22 17:13:30 -05:00
Jack Christensen
87d771ef9c Prettier ci.yml 2023-07-22 17:12:56 -05:00
Jack Christensen
492283b90b zeronull.Timestamptz should use pgtype.Timestamptz
https://github.com/jackc/pgx/issues/1694
2023-07-22 08:35:32 -05:00
James Hartig
e665f74c99 fix TestPoolBackgroundChecksMinConns and NewConnsCount
Previously it was checking TotalConns but that includes ConstructingConns.
Instead it should directly check IdleConns so the next Acquire takes one of
those and doesn't make a 3rd connection. The check against the context was
also wrong, which prevented this from timing out after 2 minutes.

This also fixes a bug where NewConnsCount was not correctly counting
connections created by Acquire directly.

Fixes #1690
2023-07-22 08:28:39 -05:00
Rafi Shamim
f90e86fd8d Unskip TestConnCopyFromLarge for CockroachDB
This test is passing now.
2023-07-22 07:11:47 -05:00
Jack Christensen
88b49d48f6 Disable TestPoolBackgroundChecksMinConns on Windows
https://github.com/jackc/pgx/issues/1690
2023-07-19 21:20:26 -05:00
Jack Christensen
2506cf3666 Make CI badge link 2023-07-19 21:12:49 -05:00
Jack Christensen
d58fe2d53c Fix json scan of non-string pointer to pointer
https://github.com/jackc/pgx/issues/1691
2023-07-19 20:54:05 -05:00
Jack Christensen
ef9e26a5d5 Check nil in defer
A panic might mean that pbr is nil.

https://github.com/jackc/pgx/issues/1689
2023-07-15 10:16:28 -05:00
Evan Jones
6703484a0d go.mod: run go mod tidy; removes golang.org/x/sys
I'm not sure exactly what commit removed the usage of this module,
but it seems worth simplifying the dependencies.
2023-07-15 10:12:11 -05:00
Jack Christensen
c513e2e435 Fix: pgxpool: background health check cannot overflow pool
It was previously possible for a connection to be created while the
background health check was running. The health check could create
connection(s) in excess of the maximum pool size in this case.

https://github.com/jackc/pgx/issues/1660
2023-07-15 10:09:53 -05:00
smaher-edb
f47f0cf823 connect_timeout is not obeyed for sslmode=allow|prefer
The connect_timeout given in the conn string was not obeyed if sslmode was not specified (the default is prefer) or was set to sslmode=allow|prefer: connecting took twice the amount of time specified by connect_timeout. While this behavior is correct when multiple hosts are provided in the conn string, it is not correct for a single host, and it also did not match libpq.

The root cause is that to implement sslmode=allow|prefer, connections are tried twice: first with TLSConfig, and if that doesn't work, then without TLSConfig. The fix now uses the same context when the same host is being retried, so a single host gets a single connect_timeout. This change doesn't affect the existing multi-host behavior.

The goal of this PR is to close issue [jackc/pgx/issues/1672](https://github.com/jackc/pgx/issues/1672)
2023-07-15 09:49:09 -05:00
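For illustration (connection string assumed), the settings involved:

```go
package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	// With sslmode=prefer a single host may be attempted twice (with TLS,
	// then without). After this fix, both attempts against the same host
	// share one connect_timeout deadline rather than getting 5 seconds each.
	config, err := pgx.ParseConfig("postgres://app@db.example.com:5432/app?connect_timeout=5&sslmode=prefer")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(config.ConnectTimeout) // 5s
}
```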
Christoph Engelbert (noctarius)
bd3e0d422c Fixes #1684 QCharArrayOID being defined with the wrong OID 2023-07-15 09:44:48 -05:00
Jack Christensen
2f6fcf8eb0 RowTo(AddrOf)StructByPos ignores fields with "-" db tag
https://github.com/jackc/pgx/discussions/1682
2023-07-15 09:39:20 -05:00
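A sketch of the tag behavior (types and query assumed):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

type Account struct {
	ID      int64
	Email   string
	Derived string `db:"-"` // skipped: not matched to any column by position
}

// getAccount maps columns to ID and Email positionally; Derived is ignored.
func getAccount(ctx context.Context, conn *pgx.Conn, id int64) (Account, error) {
	rows, _ := conn.Query(ctx, `select id, email from accounts where id = $1`, id)
	return pgx.CollectOneRow(rows, pgx.RowToStructByPos[Account])
}
```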
Jack Christensen
038fc448c1 Release v5.4.2 2023-07-11 21:29:54 -05:00
Jack Christensen
95aa87f2e8 exitPotentialWriteReadDeadlock stops bgReader
It's not enough to stop the slowWriteTimer, because the bgReader may
have been started.
2023-07-11 21:29:11 -05:00
Jack Christensen
f512b9688b Add PgConn.SyncConn
This provides a way to ensure it is safe to directly read or write to
the underlying net.Conn.

https://github.com/jackc/pgx/issues/1673
2023-07-11 21:29:11 -05:00
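A hedged sketch of the intended use, based only on the description above:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5/pgconn"
)

// rawConn waits until it is safe to use the underlying net.Conn directly,
// for example to hand the connection to custom protocol code.
func rawConn(ctx context.Context, pg *pgconn.PgConn) error {
	if err := pg.SyncConn(ctx); err != nil { // wait for internal I/O to settle
		return err
	}
	netConn := pg.Conn() // now safe for direct reads and writes
	_ = netConn
	return nil
}
```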
Jack Christensen
05440f9d3f Drastically increase allowed test times for potato CI
The context timeouts for tests are designed to give a better error
message when something hangs rather than the test just timing out.
Unfortunately, the potato CI frequently has some test or another
randomly take a long time. While the increased times are somewhat less
than optimal on a real computer, hopefully this will solve the
flickering CI.
2023-07-11 21:16:08 -05:00
Jack Christensen
e0c70201dc Skip json format test on CockroachDB 2023-07-11 20:51:22 -05:00
Jack Christensen
524f661136 Fix JSON encoding for pointer to structs implementing json.Marshaler
https://github.com/jackc/pgx/issues/1681
2023-07-11 20:28:36 -05:00
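The shape of the fixed case, sketched (type assumed):

```go
package example

import "encoding/json"

// Money implements json.Marshaler on the pointer receiver. The fix makes
// encoding a *Money into a json/jsonb column honor this method instead of
// falling back to default struct encoding.
type Money struct {
	Cents int64
}

func (m *Money) MarshalJSON() ([]byte, error) {
	return json.Marshal(map[string]int64{"cents": m.Cents})
}

var _ json.Marshaler = (*Money)(nil) // compile-time interface check
```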
Dan McGee
507a9e9ad3 Remove some now unused pgtype code
Most of this is in conversion, and I assume it became unused with some
of the v5 changes and refactors to a codec-based approach.

There are likely a few more cleanups to be made, but these ones seemed
easy and safe to start with.
2023-07-10 20:23:42 -05:00
Dan McGee
0328d314ea Use bytes.Equal rather than bytes.Compare ==/!= 0
As recommended by go-staticcheck, but also might be a bit more efficient
for the compiler to implement, since we don't care about which slice of
bytes is greater than the other one.
2023-07-08 12:08:05 -05:00
Jack Christensen
cd46cdd450 Recreate the frontend in Construct with the new bgReader
https://github.com/jackc/pgx/pull/1629#discussion_r1251472215
2023-07-08 11:39:39 -05:00
Adrian-Stefan Mares
2bf5a61401 fix: Do not use infinite timers 2023-07-08 11:24:39 -05:00
Evan Jones
dc94db6b3d pgtype.Hstore: add a round-trip test for binary and text codecs
This ensures the output of Encode can pass through Scan and produce
the same input. This found two minor problems with the text codec.
These are not bugs: these situations do not happen when using pgx
with Postgres. However, I think it is worth fixing to ensure the
code is internally consistent.

The problems with the text codec are:

* It did not correctly distinguish between nil and empty. This is not
  a problem with Postgres, since NULL values are marked separately,
  but the binary codec distinguishes between them, so it seems like
  the text codec should as well.
* It did not output spaces between keys. Postgres produces output in
  this format, and the parser now only strictly parses the Postgres
  format. This is not a bug, but seems like a good idea.
2023-06-29 17:25:47 -05:00
Gerasimos (Makis) Maropoulos
b68e7b2a68 README: Add kataras/pgx-golog to 3rd-party loggers 2023-06-24 18:23:15 -05:00
Brandon Kauffman
1dd69f86a1 Enable failover efforts when pg_hba.conf disallows non-ssl connections
Copy of https://github.com/jackc/pgconn/pull/133
2023-06-24 06:41:35 -05:00
Jack Christensen
8e6cf8f3a5 Add comment to test 2023-06-20 08:49:33 -05:00
Jack Christensen
91cba90e8d Fix: RowScanner errors are fatal to Rows
https://github.com/jackc/pgx/issues/1654
2023-06-20 08:48:06 -05:00
Jack Christensen
0d14b87140 Because CI runs on a potato 2023-06-20 08:43:06 -05:00
Nicola Murino
e79efdacf9 CI: run tests in verbose mode
It's helpful to have more detailed logs, such as how long it took to
run a single test
2023-06-19 17:06:21 -05:00
Nicola Murino
20a40120ed TestQueryEncodeError: crdb now returns the same error as postgres 2023-06-19 17:06:21 -05:00
Nicola Murino
aa263d4352 CockroachDB tests: use a more recent version 2023-06-19 17:06:21 -05:00
Nicola Murino
7fccc604af stdlib: add a concurrency test 2023-06-19 17:06:21 -05:00
Jack Christensen
34f17a6048 Allow more time for test on slow CI 2023-06-18 08:36:03 -05:00
Jack Christensen
74ab538d2a Release v5.4.1 2023-06-18 08:27:21 -05:00
Lev Zakharov
7c386112e3 fix concurrency bug in pgtype.defaultMap (#1650) 2023-06-18 08:23:56 -05:00
Jack Christensen
9a5ead9048 Add TxOptions.BeginQuery to allow overriding the default BEGIN query
https://github.com/jackc/pgx/issues/1643
2023-06-18 06:43:17 -05:00
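A hedged sketch of the new option; the BEGIN variant shown is an assumed example (e.g. a CockroachDB priority):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// beginHighPriority starts a transaction with a custom BEGIN statement
// instead of the default one built from IsoLevel/AccessMode/DeferrableMode.
func beginHighPriority(ctx context.Context, conn *pgx.Conn) (pgx.Tx, error) {
	return conn.BeginTx(ctx, pgx.TxOptions{
		BeginQuery: "BEGIN TRANSACTION PRIORITY HIGH", // assumed example
	})
}
```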
Jack Christensen
737b5af236 Allow more time for test on slow CI 2023-06-17 19:03:15 -05:00
Jack Christensen
f20070650f Make TestPoolBackgroundChecksMinConns less timing sensitive for CI 2023-06-17 17:34:07 -05:00
Evan Jones
e5db6a0467 pgtype array: Fix encoding of vtab \v
Arrays with values that start or end with vtab ("\v") must be quoted.
Postgres's array parser skips leading and trailing whitespace with
the array_isspace() function, which is slightly different from the
scanner_isspace() function that was previously linked. Add a test
that reproduces this failure, and fix the definition of isSpace.

This also includes a change to use strings.EqualFold which should
really not matter, but does not require copying the string.
2023-06-17 17:15:58 -05:00
Jack Christensen
5b7cc8e215 Make TestConnCheckConn less timing sensitive for CI 2023-06-17 17:12:58 -05:00
Evan Jones
bc8b1ca320 remove the single backing string optimization
This is a bit slower than using this optimization, but ensures this
version does not change garbage collection behavior. This does still
use a single []string for all the *string value pointers, because
that is what text parsing already does. This makes the two behave
similarly.

benchstat results of master versus this version:

                               │  orig.txt   │   new-binary-no-share-string.txt    │
                               │   sec/op    │   sec/op     vs base                │
HstoreScan/databasesql.Scan-10   82.11µ ± 1%   81.71µ ± 2%        ~ (p=0.280 n=10)
HstoreScan/text-10               83.30µ ± 1%   82.45µ ± 1%   -1.02% (p=0.000 n=10)
HstoreScan/binary-10             15.99µ ± 2%   10.12µ ± 1%  -36.67% (p=0.000 n=10)
geomean                          47.82µ        40.86µ       -14.56%

                               │   orig.txt   │   new-binary-no-share-string.txt    │
                               │     B/op     │     B/op      vs base               │
HstoreScan/databasesql.Scan-10   56.23Ki ± 0%   56.23Ki ± 0%       ~ (p=0.128 n=10)
HstoreScan/text-10               65.12Ki ± 0%   65.12Ki ± 0%       ~ (p=0.541 n=10)
HstoreScan/binary-10             21.09Ki ± 0%   19.87Ki ± 0%  -5.75% (p=0.000 n=10)
geomean                          42.58Ki        41.75Ki       -1.95%

                               │  orig.txt  │    new-binary-no-share-string.txt    │
                               │ allocs/op  │ allocs/op   vs base                  │
HstoreScan/databasesql.Scan-10   744.0 ± 0%   744.0 ± 0%        ~ (p=1.000 n=10) ¹
HstoreScan/text-10               743.0 ± 0%   743.0 ± 0%        ~ (p=1.000 n=10) ¹
HstoreScan/binary-10             464.0 ± 0%   316.0 ± 0%  -31.90% (p=0.000 n=10)
geomean                          635.4        559.0       -12.02%
¹ all samples are equal

benchstat results of the version with one string and this version:

                               │ new-binary-share-everything.txt │    new-binary-no-share-string.txt    │
                               │             sec/op              │    sec/op     vs base                │
HstoreScan/databasesql.Scan-10                       81.80µ ± 1%    81.71µ ± 2%        ~ (p=1.000 n=10)
HstoreScan/text-10                                   82.77µ ± 1%    82.45µ ± 1%        ~ (p=0.063 n=10)
HstoreScan/binary-10                                 7.330µ ± 2%   10.124µ ± 1%  +38.13% (p=0.000 n=10)
geomean                                              36.75µ         40.86µ       +11.18%

                               │ new-binary-share-everything.txt │   new-binary-no-share-string.txt    │
                               │              B/op               │     B/op      vs base               │
HstoreScan/databasesql.Scan-10                      56.23Ki ± 0%   56.23Ki ± 0%       ~ (p=0.232 n=10)
HstoreScan/text-10                                  65.12Ki ± 0%   65.12Ki ± 0%       ~ (p=0.218 n=10)
HstoreScan/binary-10                                20.73Ki ± 0%   19.87Ki ± 0%  -4.11% (p=0.000 n=10)
geomean                                             42.34Ki        41.75Ki       -1.39%

                               │ new-binary-share-everything.txt │     new-binary-no-share-string.txt     │
                               │            allocs/op            │  allocs/op   vs base                   │
HstoreScan/databasesql.Scan-10                        744.0 ± 0%    744.0 ± 0%         ~ (p=1.000 n=10) ¹
HstoreScan/text-10                                    743.0 ± 0%    743.0 ± 0%         ~ (p=1.000 n=10) ¹
HstoreScan/binary-10                                  41.00 ± 0%   316.00 ± 0%  +670.73% (p=0.000 n=10)
geomean                                               283.0         559.0        +97.53%
¹ all samples are equal
2023-06-16 15:31:37 -05:00
Evan Jones
2de94187f5 hstore: Make binary parsing 2X faster
* use []string for value string pointers: one allocation instead of
  one per value.
* use one string for all key/value pairs, instead of one for each.

After this change, one Hstore will share two allocations: one string
and one []string. The disadvantage is that it cannot be deallocated
until all key/value pairs are unused. This means if an application
takes a single key or value from the Hstore and holds on to it, its
memory footprint will increase. I would guess this is an unlikely
problem, but it is possible.

The benchstat results from my M1 Max are below.

goos: darwin
goarch: arm64
pkg: github.com/jackc/pgx/v5/pgtype
                               │   orig.txt   │               new.txt               │
                               │    sec/op    │   sec/op     vs base                │
HstoreScan/databasesql.Scan-10    82.11µ ± 1%   82.66µ ± 2%        ~ (p=0.436 n=10)
HstoreScan/text-10                83.30µ ± 1%   84.24µ ± 3%        ~ (p=0.165 n=10)
HstoreScan/binary-10             15.987µ ± 2%   7.459µ ± 6%  -53.35% (p=0.000 n=10)
geomean                           47.82µ        37.31µ       -21.98%

                               │   orig.txt   │               new.txt               │
                               │     B/op     │     B/op      vs base               │
HstoreScan/databasesql.Scan-10   56.23Ki ± 0%   56.23Ki ± 0%       ~ (p=0.324 n=10)
HstoreScan/text-10               65.12Ki ± 0%   65.12Ki ± 0%       ~ (p=0.675 n=10)
HstoreScan/binary-10             21.09Ki ± 0%   20.73Ki ± 0%  -1.70% (p=0.000 n=10)
geomean                          42.58Ki        42.34Ki       -0.57%

                               │  orig.txt   │               new.txt                │
                               │  allocs/op  │ allocs/op   vs base                  │
HstoreScan/databasesql.Scan-10    744.0 ± 0%   744.0 ± 0%        ~ (p=1.000 n=10) ¹
HstoreScan/text-10                743.0 ± 0%   743.0 ± 0%        ~ (p=1.000 n=10) ¹
HstoreScan/binary-10             464.00 ± 0%   41.00 ± 0%  -91.16% (p=0.000 n=10)
geomean                           635.4        283.0       -55.46%
¹ all samples are equal
2023-06-16 15:31:37 -05:00
Evan Jones
07670dddca do not share the original input string
This allows the original input string to be garbage collected, so it
should not change the memory footprint. This is slower than the
version that shares a string, but only by a small amount. It is still
faster than binary parsing (until that is optimized).

benchstat difference of original versus this version:

                               │  orig.txt   │     new-do-not-share-string.txt     │
                               │   sec/op    │   sec/op     vs base                │
HstoreScan/databasesql.Scan-10   82.11µ ± 1%   14.24µ ± 2%  -82.66% (p=0.000 n=10)
HstoreScan/text-10               83.30µ ± 1%   14.97µ ± 1%  -82.03% (p=0.000 n=10)
HstoreScan/binary-10             15.99µ ± 2%   15.80µ ± 0%   -1.16% (p=0.024 n=10)
geomean                          47.82µ        14.99µ       -68.66%

                               │   orig.txt   │     new-do-not-share-string.txt      │
                               │     B/op     │     B/op      vs base                │
HstoreScan/databasesql.Scan-10   56.23Ki ± 0%   20.11Ki ± 0%  -64.24% (p=0.000 n=10)
HstoreScan/text-10               65.12Ki ± 0%   29.00Ki ± 0%  -55.47% (p=0.000 n=10)
HstoreScan/binary-10             21.09Ki ± 0%   21.09Ki ± 0%        ~ (p=0.722 n=10)
geomean                          42.58Ki        23.08Ki       -45.80%

                               │  orig.txt  │     new-do-not-share-string.txt      │
                               │ allocs/op  │ allocs/op   vs base                  │
HstoreScan/databasesql.Scan-10   744.0 ± 0%   340.0 ± 0%  -54.30% (p=0.000 n=10)
HstoreScan/text-10               743.0 ± 0%   339.0 ± 0%  -54.37% (p=0.000 n=10)
HstoreScan/binary-10             464.0 ± 0%   464.0 ± 0%        ~ (p=1.000 n=10) ¹
geomean                          635.4        376.8       -40.70%
¹ all samples are equal

benchstat difference of the shared string versus not:

                               │ new-share-string.txt │     new-do-not-share-string.txt     │
                               │        sec/op        │   sec/op     vs base                │
HstoreScan/databasesql.Scan-10            10.57µ ± 2%   14.24µ ± 2%  +34.69% (p=0.000 n=10)
HstoreScan/text-10                        11.60µ ± 2%   14.97µ ± 1%  +29.03% (p=0.000 n=10)
HstoreScan/binary-10                      15.87µ ± 2%   15.80µ ± 0%        ~ (p=0.280 n=10)
geomean                                   12.48µ        14.99µ       +20.07%

                               │ new-share-string.txt │     new-do-not-share-string.txt      │
                               │         B/op         │     B/op      vs base                │
HstoreScan/databasesql.Scan-10           11.68Ki ± 0%   20.11Ki ± 0%  +72.17% (p=0.000 n=10)
HstoreScan/text-10                       20.58Ki ± 0%   29.00Ki ± 0%  +40.93% (p=0.000 n=10)
HstoreScan/binary-10                     21.08Ki ± 0%   21.09Ki ± 0%        ~ (p=0.427 n=10)
geomean                                  17.17Ki        23.08Ki       +34.39%

                               │ new-share-string.txt │      new-do-not-share-string.txt       │
                               │      allocs/op       │  allocs/op   vs base                   │
HstoreScan/databasesql.Scan-10             44.00 ± 0%   340.00 ± 0%  +672.73% (p=0.000 n=10)
HstoreScan/text-10                         44.00 ± 0%   339.00 ± 0%  +670.45% (p=0.000 n=10)
HstoreScan/binary-10                       464.0 ± 0%    464.0 ± 0%         ~ (p=1.000 n=10) ¹
geomean                                    96.49         376.8       +290.47%
2023-06-16 15:30:54 -05:00
Evan Jones
d48d36dc02 pgtype/hstore: Make text parsing about 6X faster
I am working on an application that uses hstore types, and we found
that returning the values is slow, particularly when using the text
protocol, such as when using database/sql. This improves parsing to
be about 6X faster (currently faster than binary). The changes are:

* referencing the original string instead of copying into new strings
  (very large win)
* using string.IndexByte to scan double quoted strings: it has
  architecture-specific assembly implementations, and most of the
  time is spent in key/value strings.
* estimating the number of key/value pairs to allocate the correct
  size of the slice and map up front. This reduces the number of
  allocations and bytes allocated by a factor of 2, and was a small
  CPU win.
* parsing directly into the Hstore, rather than copying into it.

This parser is stricter than the old one. It only accepts hstore
strings serialized by Postgres. The old one was already stricter
than Postgres's own parser, but previously accepted any whitespace
character after a comma. This one only accepts space. Example:

  "k1"=>"v1",\t"k2"=>"v2"

Postgres only ever uses ", " as the separator. See hstore_out:
https://github.com/postgres/postgres/blob/master/contrib/hstore/hstore_io.c

The result of using benchstat to compare the benchmark on my M1 Pro
with the following command line is below. The new text parser is now
faster than the binary parser. I will improve the binary parser in a
separate change.

for i in $(seq 10); do go test ./pgtype -run=none -bench=BenchmarkHstoreScan -benchtime=1s >> new.txt; done

goos: darwin
goarch: arm64
pkg: github.com/jackc/pgx/v5/pgtype
                               │  orig.txt   │               new.txt               │
                               │   sec/op    │   sec/op     vs base                │
HstoreScan/databasesql.Scan-10   82.11µ ± 1%   10.51µ ± 0%  -87.20% (p=0.000 n=10)
HstoreScan/text-10               83.30µ ± 1%   11.49µ ± 1%  -86.20% (p=0.000 n=10)
HstoreScan/binary-10             15.99µ ± 2%   15.77µ ± 1%   -1.35% (p=0.007 n=10)
geomean                          47.82µ        12.40µ       -74.08%

                               │   orig.txt   │               new.txt                │
                               │     B/op     │     B/op      vs base                │
HstoreScan/databasesql.Scan-10   56.23Ki ± 0%   11.68Ki ± 0%  -79.23% (p=0.000 n=10)
HstoreScan/text-10               65.12Ki ± 0%   20.58Ki ± 0%  -68.40% (p=0.000 n=10)
HstoreScan/binary-10             21.09Ki ± 0%   21.09Ki ± 0%        ~ (p=0.378 n=10)
geomean                          42.58Ki        17.18Ki       -59.66%

                               │  orig.txt   │               new.txt                │
                               │  allocs/op  │ allocs/op   vs base                  │
HstoreScan/databasesql.Scan-10   744.00 ± 0%   44.00 ± 0%  -94.09% (p=0.000 n=10)
HstoreScan/text-10               743.00 ± 0%   44.00 ± 0%  -94.08% (p=0.000 n=10)
HstoreScan/binary-10              464.0 ± 0%   464.0 ± 0%        ~ (p=1.000 n=10) ¹
geomean                           635.4        96.49       -84.81%
¹ all samples are equal
2023-06-16 15:30:54 -05:00
Nicola Murino
eb2807bda5 copy tests: test all supported modes
if we run the test in parallel, we always test the latest mode

see here

https://github.com/golang/go/wiki/LoopvarExperiment

also fix a lint warning about pgtype.Vec2
2023-06-15 20:54:24 -05:00
Nicola Murino
b1f8055584 TestConnectWithFallback: increase timeout
on Windows connecting on a closed port takes about 2 seconds.
You can test with something like this

	start := time.Now()
	_, err := d.DialContext(context.Background(), "tcp", "127.0.0.1:1")
	fmt.Printf("finished, time %s, err: %v\n", time.Since(start), err)

This seems by design

https://groups.google.com/g/comp.os.ms-windows.programmer.win32/c/jV6kRVY3BqM

Generally TestConnectWithFallback takes about 8-9 seconds on Windows.
Increase timeout to avoid random failures under load
2023-06-15 20:54:24 -05:00
Jack Christensen
461b9fa36e Release v5.4.0 2023-06-14 09:41:17 -05:00
Jack Christensen
45520d5a11 Document pgtype.Map and pgtype.Type are immutable after registration 2023-06-14 08:27:04 -05:00
Lev Zakharov
90f9aad67f add singleton pgtype.Map for default type mappings 2023-06-14 08:21:28 -05:00
Jack Christensen
5f28621394 Add docs clarifying that FieldDescriptions may return nil
https://github.com/jackc/pgx/issues/1634
2023-06-14 07:42:11 -05:00
Klaus
c542df4fb4 added MarshalJSON and UnmarshalJSON to timestamp and added their tests (based on timestamptz implementation) 2023-06-12 09:52:49 -05:00
Jack Christensen
34eddf9983 Increase slowWriteTimer to 15ms and document why 2023-06-12 09:39:26 -05:00
Jack Christensen
5d4f9018bf failed to write startup message error should be normalized 2023-06-12 09:39:26 -05:00
Jack Christensen
482e56a79b Fix race condition when CopyFrom is cancelled. 2023-06-12 09:39:26 -05:00
Jack Christensen
3ea2f57d8b Deprecate CheckConn in favor of Ping 2023-06-12 09:39:26 -05:00
Jack Christensen
26c79eb215 Handle writes that could deadlock with reads from the server
This commit adds a background reader that can optionally buffer reads.
It is used whenever a potentially blocking write is made to the server.
The background reader is started on a slight delay so there should be no
meaningful performance impact as it doesn't run for quick queries and
its overhead is minimal relative to slower queries.
2023-06-12 09:39:26 -05:00
Jack Christensen
85136a8efe Restore pgx v4 style CopyFrom implementation
This approach uses an extra goroutine to write while the main goroutine
continues to read. This avoids the need to use non-blocking I/O.
2023-06-12 09:39:26 -05:00
Jack Christensen
4410fc0a65 Remove nbconn
The non-blocking IO system was designed to solve three problems:

1. Deadlock that can occur when both sides of a connection are blocked
   writing because all buffers between them are full.
2. The inability to use a write deadline with a tls.Conn without killing
   the connection.
3. Efficiently check if a connection has been closed before writing.
   This reduces the cases where the application doesn't know if a query
   that does an INSERT/UPDATE/DELETE was actually sent to the server or
   not.

However, the nbconn package is extraordinarily complex, has been a
source of very tricky bugs, and has OS specific code paths. It also does
not work at all with underlying net.Conn implementations that do not
have platform specific non-blocking IO syscall support and do not
properly implement deadlines. In particular, this is the case with
golang.org/x/crypto/ssh.

I believe the deadlock problem can be solved with a combination of a
goroutine for CopyFrom like v4 used and a watchdog for regular queries
that uses time.AfterFunc.

The write deadline problem actually should be ignorable. We check for
context cancellation before sending a query and the actual Write should
be almost instant as long as the underlying connection is not blocked.
(We should only have to wait until it is accepted by the OS, not until
it is fully sent.)

Efficiently checking if a connection has been closed is probably the
hardest to solve without non-blocking reads. However, the existing code
only solves part of the problem. It can detect a closed or broken
connection the OS knows about, but it won't actually detect other types
of broken connections such as a network interruption. This is currently
implemented in CheckConn and called automatically when checking a
connection out of the pool that has been idle for over one second. I
think that changing CheckConn to a very short deadline read and changing
the pool to do an actual Ping would be an acceptable solution.

Remove nbconn and non-blocking code. This does not leave the system in
an entirely working state. In particular, CopyFrom is broken, deadlocks
can occur for extremely large queries or batches, and PgConn.CheckConn
is now a `select 1` ping. These will be resolved in subsequent commits.
2023-06-12 09:39:26 -05:00
Massimo Costa
9cfdd21f1c feat: add reference to pgx-slog adapter 2023-06-12 09:37:20 -05:00
Evan Jones
4d643b75f5 pgtype/hstore_test.go: Extend coverage of scan benchmark
I am working on an application that reads a lot of hstore values, and
have discovered that scanning it is fairly slow. I'm working on some
improvements, but first I wanted a better benchmark. This adds more
realistic data, and extends it to cover the three APIs: database/sql,
and pgconn.Rows.Scan with both text and binary protocols.
2023-06-12 09:17:24 -05:00
Jack Christensen
490f70fc5f Fix docs for QueryExecModeDescribeExec with connection poolers
https://github.com/jackc/pgx/issues/1635
2023-06-11 08:26:02 -05:00
Evan Jones
1b68b5970e pgtype/hstore: Save 2 allocs in database/sql Scan implementation
Remove unneeded string to []byte to string conversion, which saves 2
allocs and should make Hstore text scanning slightly faster.

The Hstore.Scan() function takes a string as input, converts it to
[]byte, and calls scanPlanTextAnyToHstoreScanner.Scan(). That
function converts []byte back to string and calls parseHstore. This
refactors scanPlanTextAnyToHstoreScanner.Scan into
scanPlanTextAnyToHstoreScanner.scanString so the database/sql Scan
function can call it directly, bypassing this conversion.

The added Benchmark shows this saves 2 allocs for longer strings, and
saves about 5% CPU overall on my M1 Pro. benchstat output:

goos: darwin
goarch: arm64
pkg: github.com/jackc/pgx/v5/pgtype
              │  orig.txt   │              new.txt               │
              │   sec/op    │   sec/op     vs base               │
HstoreScan-10   1.334µ ± 2%   1.257µ ± 2%  -5.77% (p=0.000 n=10)

              │   orig.txt   │               new.txt               │
              │     B/op     │     B/op      vs base               │
HstoreScan-10   2.094Ki ± 0%   1.969Ki ± 0%  -5.97% (p=0.000 n=10)

              │  orig.txt  │              new.txt              │
              │ allocs/op  │ allocs/op   vs base               │
HstoreScan-10   36.00 ± 0%   34.00 ± 0%  -5.56% (p=0.000 n=10)
2023-06-07 15:35:22 -05:00
Evan Jones
ee04d4a74d pgtype/hstore: Avoid Postgres Mac OS X parsing bug
Postgres on Mac OS X has a bug in how it parses hstore text values
that causes it to misinterpret some Unicode values as spaces. This
causes values sent by pgx to be misinterpreted. To avoid this, always
quote hstore values, which is how Postgres serializes them itself.
The test change fails on Mac OS X without this fix.

While I suspect this should not be performance critical for any
application, I added a quick benchmark to test the performance of the
encoding. This change actually makes encoding slightly faster on my
M1 Pro. The output from the benchstat program on this benchmark is:

goos: darwin
goarch: arm64
pkg: github.com/jackc/pgx/v5/pgtype
                          │   orig.txt   │           new-quotes.txt            │
                          │    sec/op    │   sec/op     vs base                │
HstoreSerialize/text-10      207.1n ± 0%   142.3n ± 1%  -31.31% (p=0.000 n=10)
HstoreSerialize/binary-10   100.10n ± 0%   99.64n ± 1%   -0.45% (p=0.013 n=10)
geomean                      144.0n        119.1n       -17.31%

I have also attempted to fix the Postgres bug, but it will take a
long time for this fix to get upstream:

https://www.postgresql.org/message-id/CA%2BHWA9awUW0%2BRV_gO9r1ABZwGoZxPztcJxPy8vMFSTbTfi4jig%40mail.gmail.com
2023-06-07 15:29:25 -05:00
Jack Christensen
d9560c78b8 Use tx instead of underlying conn in test
Improves clarity
2023-06-03 07:59:28 -05:00
Jack Christensen
608f39f426 Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
Otherwise a connection would be leaked and closing the pool would block.

https://github.com/jackc/pgx/issues/1628
2023-06-03 07:39:56 -05:00
Nicola Murino
229d2aaa49 TestConnCopyFromBinary: increase context timeout 2023-06-03 06:45:28 -05:00
Nicola Murino
b4314ddaf7 TestConnCopyFromSlowFailRace: increase context timeout
On Windows time.Sleep(time.Millisecond) will sleep for 15 milliseconds
2023-06-03 06:45:28 -05:00
Nicola Murino
28bd5b3843 TestConnectTimeoutStuckOnTLSHandshake: allow more time to complete
to avoid random errors in Windows CI
2023-06-03 06:45:28 -05:00
Nicola Murino
fb47e1abbb TestContextWatcherStress: reduce sleep counts 2023-06-03 06:45:28 -05:00
Nicola Murino
c861bce438 CancelRequest: don't try to read the reply
Postgres will just process the request and close the connection
2023-06-03 06:45:28 -05:00
Nicola Murino
46d91255b0 remove timeout for test cases on Windows 2023-06-03 06:45:28 -05:00
Nicola Murino
ef363b59ab skipping some config parsing tests on Windows
this should be investigated and fixed
2023-06-03 06:45:28 -05:00
Nicola Murino
bad6b36c47 CI Windows: Initialize test database 2023-06-03 06:45:28 -05:00
Nicola Murino
33d4fa0fa6 TLS with Fake Non-blocking IO test is expected to fail on Windows 2023-06-03 06:45:28 -05:00
Nicola Murino
30d63caa6a CI: run basic tests on Windows 2023-06-03 06:45:28 -05:00
Nicola Murino
b0fa429fd0 add a comment explaining that nbOperMu and nbOperCnt are used on Windows 2023-06-03 06:45:28 -05:00
Nicola Murino
32c7858e61 Revert "Remove unused fields"
This reverts commit 2c1973de4634a6a83d3ba09bdcde392aaf7cfb71.
2023-06-03 06:45:28 -05:00
Pavlo Golub
c7733fe52e Update README.md
add pgxmock description
2023-05-31 07:11:41 -05:00
Jack Christensen
9720d0d63f Use context timeouts for tracelog tests 2023-05-29 11:23:21 -05:00
Jack Christensen
5f6636d028 Add context timeouts for more pgxpool tests 2023-05-29 11:15:40 -05:00
Jack Christensen
a1a97a7ca8 Add context timeouts for some pgxpool tests 2023-05-29 11:04:52 -05:00
Jack Christensen
0ec512b504 Fix: possible fail in goroutine after test has completed 2023-05-29 10:43:15 -05:00
Jack Christensen
f93b42b6ac Allow more time for TestConnExecBatchHuge 2023-05-29 10:35:38 -05:00
Jack Christensen
9f00b6f750 Use context timeouts in more tests
Tests should timeout in a reasonable time if something is stuck. In
particular this is important when testing deadlock conditions such as
can occur with the copy protocol if both the client and the server are
blocked writing until the other side does a read.
2023-05-29 10:25:57 -05:00
Jonathan Gonzalez V
4b9aa7c4f2 chore: update version of golang.org/x/crypto library from v0.6.0 to v0.9.0
During the update also the following packages were updated:

golang.org/x/sys v0.5.0 to v0.8.0
golang.org/x/text v0.7.0 to v0.9.0

Signed-off-by: Jonathan Gonzalez V <jonathan.abdiel@gmail.com>
2023-05-29 09:20:51 -05:00
Jack Christensen
2c1973de46 Remove unused fields 2023-05-27 08:18:47 -05:00
Jack Christensen
b3739c1289 pgconn.CheckConn locks connection
This ensures that a closed connection at the pgconn layer is not
considered okay when the background closing of the net.Conn is still in
progress.

This also means that CheckConn cannot be called when the connection is
locked (for example, by an in-progress query). But that seems
reasonable. It's not exactly clear that it would have ever worked
anyway.

https://github.com/jackc/pgx/issues/1618#issuecomment-1563702231
2023-05-26 06:03:25 -05:00
Alek Anokhin
70a200cff4 Fix test failures
Add bool type alias conversion in `elemKindToPointerTypes` and `underlyingNumberType`
2023-05-20 08:53:23 -05:00
Wichert Akkerman
c1c67e4e58 Fix: correctly handle bool type aliases
https://github.com/jackc/pgx/issue/1593
2023-05-20 08:53:23 -05:00
Evan Jones
9de41fac75 ParseConfig: default_query_exec_mode: Return arg in error
If the default_query_exec_mode is unknown, the returned error
previously was:

    invalid default_query_exec_mode: <nil>

This changes it to return the argument. Add a test that unknown modes
fail to parse and include this string.
2023-05-20 08:09:35 -05:00
Evan Jones
11d892dfcf pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
The tests for cancelling requests were failing when using unix
sockets. The reason is that net.Conn.RemoteAddr() calls getpeername()
to get the address. For Unix sockets, this returns the address that
was passed to bind() by the *server* process, not the address that
was passed to connect() by the *client*. For postgres, this is always
relative to the server's directory, so is a path like:

    ./.s.PGSQL.5432

Since it does not return the full absolute path, this function cannot
connect, so it cannot cancel requests. To fix it, use the connection's
config for Unix sockets. I think this should be okay, since a system
using unix sockets should not have "fallbacks". If that is incorrect,
we will need to save the address on PgConn.

Fixes the following failed tests when using Unix sockets:

--- FAIL: TestConnCancelRequest (2.00s)
    pgconn_test.go:2056:
          Error Trace:  /Users/evan.jones/pgx/pgconn/pgconn_test.go:2056
                              /Users/evan.jones/pgx/pgconn/asm_arm64.s:1172
          Error:        Received unexpected error:
                        dial unix ./.s.PGSQL.5432: connect: no such file or directory
          Test:         TestConnCancelRequest
    pgconn_test.go:2063:
          Error Trace:  /Users/evan.jones/pgx/pgconn/pgconn_test.go:2063
          Error:        Object expected to be of type *pgconn.PgError, but was <nil>
          Test:         TestConnCancelRequest
--- FAIL: TestConnContextCanceledCancelsRunningQueryOnServer (5.10s)
    pgconn_test.go:2109:
          Error Trace:  /Users/evan.jones/pgx/pgconn/pgconn_test.go:2109
          Error:        Received unexpected error:
                        timeout: context already done: context deadline exceeded
          Test:         TestConnContextCanceledCancelsRunningQueryOnServer
2023-05-20 08:08:47 -05:00
Evan Jones
0292edecb0 pgx.Conn: Fix memory leak: Delete items from preparedStatements
Previously, items were never removed from the preparedStatements map.
This means workloads that send a large number of unique queries could
run out of memory. Delete items from the map when sending the
deallocate command to Postgres. Add a test to verify this works.

Fixes https://github.com/jackc/pgx/issues/1456
2023-05-20 08:06:37 -05:00
Evan Jones
eab316e200 pgtype.Hstore: Fix quoting of whitespace; Add test
Before this change, the Hstore text protocol did not quote keys or
values containing non-space whitespace ("\r\n\v\t"). This causes
inserts with these values to fail with errors like:

    ERROR: Syntax error near "r" at position 17 (SQLSTATE XX000)

The previous version also quoted curly braces ("{}"), but they don't
seem to require quoting.

It is possible that it would be easier to just always quote the
values, which is what Postgres does when encoding its text protocol,
but this is a smaller change.
2023-05-16 07:02:55 -05:00
Evan Jones
8ceef73b84 pgtype.parseHstore: Reject invalid input; Fix error messages
The parseHstore function did not check the return value from
p.Consume() after a ', ' sequence. It expects a doublequote '"' that
starts the next key, but would accept any character. This means it
accepted invalid input such as:

    "key1"=>"b", ,key2"=>"value"

Add a unit test that covers this case.
Fix a couple of the nearby error strings while looking at this.

Found by looking at staticcheck warnings:

    pgtype/hstore.go:434:6: this value of end is never used (SA4006)
    pgtype/hstore.go:434:6: this value of r is never used (SA4006)
2023-05-15 18:10:20 -05:00
Evan Jones
bbcc4fc0b8 pgtype/hstore_test.go: Add coverage for text protocol
The existing test registers pgtype.Hstore in the text map, then uses
the query modes that use the binary protocol. The existing test did
not use the text parsing code. Add a version of the test that uses
pgtype.Hstore as the input and output argument in all query modes,
and tests it without registering the codec.
2023-05-15 18:09:31 -05:00
Evan Cordell
cead918e18 run tests that rely on backend PID to run against cockroach
cockroach has supported backend PIDs on connections since 22.1:
https://www.cockroachlabs.com/docs/releases/v22.1.html#v22-1-3-sql-language-changes
2023-05-15 18:06:08 -05:00
Evan Cordell
7f2bb9595f add BeforeClose to pgxpool.Pool 2023-05-15 18:06:08 -05:00
Evan Jones
d8b38b28be pgtype/hstore.go: Remove unused quoteHstore{Element,Replacer}
These are unused. The code uses quoteArrayElement instead.
2023-05-13 10:03:22 -05:00
Evan Jones
2a86501e86 Fix hstore NULL versus empty
When running queries with the hstore type registered, and with simple
mode queries, the scan implementation does not correctly distinguish
between NULL and empty. Fix the implementation and add a test to
verify this.
2023-05-13 09:34:30 -05:00
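The distinction in Go terms, as a sketch (pgtype.Hstore is a map[string]*string; semantics as described in the commit):

```go
package example

import "github.com/jackc/pgx/v5/pgtype"

func ptr(s string) *string { return &s }

var (
	null  pgtype.Hstore                  // nil map: encodes/scans as SQL NULL
	empty = pgtype.Hstore{}              // empty map: zero pairs, not NULL
	pair  = pgtype.Hstore{"k": ptr("v")} // key with a non-NULL value
	nullV = pgtype.Hstore{"k": nil}      // key whose value is NULL
)
```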
Jack Christensen
f59e8bf555 Fix: RowToStructByPos with embedded unexported struct
https://github.com/jackc/pgx/issues/1583
2023-04-27 21:03:58 -05:00
Lev Zakharov
c27b9b49ea support different bool string representations 2023-04-27 20:29:41 -05:00
Jack Christensen
6defa2a607 Fix error when using BatchResults.Exec
...on a select that returns an error after some rows.

This was initially found via a failure with CockroachDB, because it
seems to send a RowDescription before an error even when no rows are
returned. PostgreSQL doesn't.
2023-04-20 21:43:59 -05:00
Jack Christensen
a23a423f55 Fix pipelineBatchResults.Exec() not returning error from ResultReader 2023-04-20 21:19:41 -05:00
Jack Christensen
09371981f9 Fix pipeline batch results not closing pipeline
when error occurs while reading directly from results instead of using
a callback.

https://github.com/jackc/pgx/issues/1578
2023-04-20 20:58:04 -05:00
Jack Christensen
67f2a41587 Fix scanning a table type into a struct
Table types have system / hidden columns like tableoid, cmax, xmax, etc.
These are not included when sending or receiving composite types.

https://github.com/jackc/pgx/issues/1576
2023-04-20 20:13:37 -05:00
Simon Paredes
2cf1541bb9 wrap error 2023-04-11 18:07:05 -05:00
Vinícius Garcia
84eb2e460a Add KSQL on the 3rd party section of the README 2023-04-11 17:53:38 -05:00
Jack Christensen
847f888631 Fix scan array of record to pointer to slice of struct
https://github.com/jackc/pgx/issues/1570
2023-04-08 14:39:48 -05:00
Daniel Castro
f72a147db3 skip cockroachdb 2023-04-05 17:36:00 -05:00
Daniel Castro
8b7c699b8f proper naming 2023-04-05 17:36:00 -05:00
Daniel Castro
215ffafc74 fix tests 2023-04-05 17:36:00 -05:00
Daniel Castro
5eeaa201d9 add extra tests 2023-04-05 17:36:00 -05:00
Jack Christensen
be79f1c8f5 Allow batch callback function to override error
https://github.com/jackc/pgx/pull/1538#issuecomment-1486083411
2023-03-31 20:18:05 -05:00
cemre.mengu
ca022267db add tests 2023-03-25 10:22:11 -05:00
Cemre Mengu
2a653b4a8d fix: handle null interface for json
When using `scany` I encountered the following case. This seems to fix it.

Looks like null `jsonb` columns cause the problem. If you create a table like below you can see that the following code fails. Is this expected?

```sql
CREATE TABLE test (
	a int4 NULL,
	b int4 NULL,
	c jsonb NULL
);

INSERT INTO test (a, b, c) VALUES (1, null, null);
```

```go
package main

import (
	"context"
	"log"

	"github.com/georgysavva/scany/v2/pgxscan"
	"github.com/jackc/pgx/v5"
)

func main() {
	var rows []map[string]interface{}
	conn, _ := pgx.Connect(context.Background(), ts.PGURL().String()) // ts: the test server helper providing the conn string

	// this will fail with can't scan into dest[0]: cannot scan NULL into *interface {}
	err := pgxscan.Select(context.Background(), conn, &rows, `SELECT c from test`)

	// this works
	// err = pgxscan.Select(context.Background(), conn, &rows, `SELECT a,b from test`)

	if err != nil {
		panic(err)
	}

	log.Printf("%+v", rows)
}
```
2023-03-25 10:22:11 -05:00
Jack Christensen
7af80ae8a6 Batch Query callback is called even when there is an error
This allows the callback to handle additional error types such as
foreign key constraint violations.

See https://github.com/jackc/pgx/pull/1538.
2023-03-25 10:21:34 -05:00
Audi P. Risa P
7555c43033 add lax field to namedStructRowScanner 2023-03-25 09:57:38 -05:00
Audi P. Risa P
193bab416f add RowTo(AddrOf)StructByNameLax 2023-03-25 09:57:38 -05:00
Dmitry K
e9d64ec29d Use time.Equal instead of direct comparison 2023-03-24 17:51:34 -05:00
Dmitry K
2f1bba09c4 Guard deadline readings by mutex 2023-03-24 17:51:34 -05:00
Dmitry K
d829073b2f Improve deadline simulation 2023-03-24 17:51:34 -05:00
Dmitry K
48da6435a5 Add deadline simulation 2023-03-24 17:51:34 -05:00
Dmitry K
34e3013153 Remove commented out atomic calls 2023-03-24 17:51:34 -05:00
Dmitry K
009a377028 Use mutex to guard entire SetBlockingMode call 2023-03-24 17:51:34 -05:00
Dmitry K
e05abb83ec Better error messages 2023-03-24 17:51:34 -05:00
Dmitry K
89475c4c91 use atomic.Int32 instead of int + atomic calls 2023-03-24 17:51:34 -05:00
Dmitry K
c3d62c8783 Small comment update 2 2023-03-24 17:51:34 -05:00
Dmitry K
1298a835bc Small comment update 2023-03-24 17:51:34 -05:00
Dmitry K
b2b4fbcf57 Set socket to non-blocking mode in Read, Flush and BufferReadUntilBlock operations 2023-03-24 17:51:34 -05:00
Dmitry K
3db7d1774e Set socket to non-blocking mode before doneChan is allocated to avoid that channel leaked in case when SetBlockingMode will return error 2023-03-24 17:51:34 -05:00
Dmitry K
a83faa67f5 Small improvements 2023-03-24 17:51:34 -05:00
Dmitry K
8b5e8d9d89 Fix Windows non-blocking I/O for CopyFrom
Created based on discussion here: https://github.com/jackc/pgx/pull/1525#pullrequestreview-1344511991

Fixes https://github.com/jackc/pgx/issues/1552
2023-03-24 17:51:34 -05:00
Sergej Brazdeikis
9ae852eb58 Fix typo in error message uint32 -> uint16 2023-03-11 15:34:08 -06:00
Nicola Murino
19039e6dd1 fix build on 32-bit Windows 2023-03-07 17:09:03 -06:00
Dmitry K
0dbb0a52ab Fix realNonblockingRead, set realNonblockingRead call error to nonblockReadErr 2023-03-04 09:25:36 -06:00
Dmitry K
087b8b2ba8 Try to make windows non-blocking I/O 2023-03-04 09:25:36 -06:00
Jack Christensen
c09ddaf440 Add Windows non-blocking IO 2023-03-04 09:25:36 -06:00
Jack Christensen
80eb6e1859 Remove sleeps in test
Sleeping for a microsecond on Windows actually takes 10ms. This caused
the test to never finish. Instead use channel to ensure the two
goroutines start working at the same time and remove the sleeps.
2023-02-27 20:32:51 -06:00
Jack Christensen
7ec6ee7b0a Release v5.3.1 2023-02-27 19:57:26 -06:00
Jack Christensen
6105ca5073 Fix TestInternalNonBlockingWriteWithDeadline
The test was relying on sending so big a message that the write blocked.
However, it appears that on Windows the TCP connections over localhost
have a very large or infinite-sized buffer. Change the test to simply
set the deadline to the current time before triggering the write.
2023-02-25 17:02:55 -06:00
Jack Christensen
8f46c75e73 Fix: fake non-blocking read adaptive wait time
If the time reached the minimum time before the 5 tries were up it
would get stuck reading 1 byte at a time indefinitely.
2023-02-25 16:45:34 -06:00
Jack Christensen
38e09bda4c Fix *wrapSliceEncodePlan[T].Encode
It should pass a FlatArray[T] to the next step instead of an
anySliceArrayReflect. By using an anySliceArrayReflect, an encode of
[]github.com/google/uuid.UUID followed by []string into a PostgreSQL
uuid[] would crash. This was caused by an EncodePlan cache collision
where the second encoding used part of the cached plan of the first.

In proper usage a cache collision shouldn't be able to occur. If this
assertion proves incorrect it will be necessary to add an optional
interface to ScanPlan and EncodePlan that marks the plan as ineligible
for caching. But I have been unable to construct a failing case, and
given that ScanPlans have been cached for quite some time now without
incident I do not think it is possible. This issue only occurred due to
the bug in *wrapSliceEncodePlan[T].Encode.

https://github.com/jackc/pgx/issues/1502
2023-02-21 21:04:30 -06:00
Ch. König
9567297815 add mgx module reference to the readme file 2023-02-17 08:58:34 -06:00
Jack Christensen
42d327f660 Add text format jsonpath support 2023-02-14 19:52:47 -06:00
Jack Christensen
f17c743c3c Unwatch at end of test
https://github.com/jackc/pgx/issues/1505
2023-02-14 09:03:41 -06:00
Jack Christensen
a6ace8969b Fix: Prefer sql.Scanner before TryWrapScanPlanFuncs
This was already the case when the data type was unknown but should also
be the case when it is known.
2023-02-14 09:03:41 -06:00
Tomáš Procházka
c2e278e5d4 simplify duplicate pgx registration guard
The binary search is overkill here.
Readability first.
2023-02-13 21:08:42 -06:00
Jack Christensen
c5daa3a814 Release v5.3.0 2023-02-11 09:15:31 -06:00
Jack Christensen
f5d2da7a19 Upgrade golang.org/x/crypto and golang.org/x/text 2023-02-11 08:59:51 -06:00
Jack Christensen
b8262ace75 Upgrade to puddle v2.2.0 2023-02-11 08:57:19 -06:00
Jack Christensen
2100a64dbe Fix broken benchmarks 2023-02-10 20:26:18 -06:00
Jack Christensen
4484831550 Prefer binary format for arrays
This improves performance decoding text[].
2023-02-10 20:21:25 -06:00
Jack Christensen
1f43e2e490 Fix text format array decoding with a string of "NULL"
It was incorrectly being treated as NULL instead of 'NULL'.

fixes https://github.com/jackc/pgx/issues/1494
2023-02-10 19:59:03 -06:00
Jack Christensen
b707faea8f Fix flickering test TestBufferNonBlockingRead 2023-02-10 19:40:31 -06:00
Vitalii Solodilov
255f16b00f Register pgx driver using major version
Fixed: #1480
2023-02-10 19:18:45 -06:00
Felix Röhrich
a47e836471 make TestPointerPointerStructScan easier to read 2023-02-10 19:06:20 -06:00
Felix Röhrich
5cd8468b99 replace erroneous reflect.New with reflect.Zero in TryWrapStructScanPlan 2023-02-10 19:06:20 -06:00
Felix Röhrich
fa5fbed497 add filter for dropped attributes in getCompositeType 2023-02-07 08:45:56 -06:00
Jack Christensen
190c05cc24 CI fix: Go versions are strings
Otherwise Go 1.20 was being treated as Go 1.2.
2023-02-04 07:32:13 -06:00
Jack Christensen
c875abea84 Fix encode []any to array
https://github.com/jackc/pgx/issues/1488
2023-02-04 07:28:52 -06:00
Jack Christensen
98543e0354 Update supported Go versions and add 1.20 to CI 2023-02-04 07:01:03 -06:00
Jack Christensen
32c29a6edd Update issue template to use pgx v5 2023-02-01 19:40:25 -06:00
Jack Christensen
9963c32d4f Only count when bytes actually read 2023-01-31 20:35:44 -06:00
Jack Christensen
6bc327b3ce Find fastest possible read time for fakeNonblockingReadWaitDuration
The first 5 fake non-blocking reads are limited to 1 byte. This should
ensure that there is a measurement of a read where bytes are already
waiting in Go or the OS's read buffer.
2023-01-31 20:25:57 -06:00
Jack Christensen
f46d35610e Only set c.fakeNonblockingReadWaitDuration when it will be decreased 2023-01-31 20:25:17 -06:00
Jack Christensen
cf78472ce5 Use unix build tag
With Go 1.19 available we can use a simpler build tag.
2023-01-31 20:10:34 -06:00
Yumin Xia
766d2bba4f add UnmarshalJSON for pgtype Numeric 2023-01-30 21:33:02 -06:00
Jack Christensen
384a581e99 Avoid slightly overflowing the send copy buffer
This avoids send buffer sequences such as 65531, 13, 65531, 13, 65531,
13, 65531, 13.
2023-01-30 20:59:54 -06:00
Jack Christensen
898891a6ee Fake non-blocking read adapts its max wait time
The reason for a high max wait time was to ensure that reads aren't
cancelled when there is data waiting for them in Go or the OS's receive
buffer. Unfortunately, there is no way to know ahead of time how long
this should take.

This new code uses 2x the fastest successful read time as the max read
time. This allows the code to adapt to whatever host it is running on.

https://github.com/jackc/pgx/issues/1481
2023-01-28 09:35:52 -06:00
Jack Christensen
7019ed1edf Fix tests for iobufpool optimization 2023-01-28 09:30:12 -06:00
Jack Christensen
eee854fb06 iobufpool uses *[]byte instead of []byte to reduce allocations 2023-01-28 08:02:49 -06:00
Jack Christensen
bc754291c1 Save memory on non blocking read path
Only create RawConn.Read callback once and have it use NetConn fields.
Avoids the closure and some allocations.

https://github.com/jackc/pgx/issues/1481
2023-01-27 20:53:30 -06:00
Jack Christensen
2c7d86a543 Only create RawConn.Write callback once
This saves an allocation on every call.

https://github.com/jackc/pgx/issues/1481
2023-01-27 20:34:21 -06:00
Jack Christensen
42a47194a2 Memoize encode plans
This significantly reduces memory allocations in paths that repeatedly
encode the same type of values such as CopyFrom.

https://github.com/jackc/pgx/issues/1481
2023-01-27 20:19:06 -06:00
Jack Christensen
7941518809 BufferReadUntilBlock should release buf when no bytes read
This was causing allocations every time there was a non-blocking read
with nothing to read.

https://github.com/jackc/pgx/issues/1481
2023-01-27 18:03:38 -06:00
Alexey Palazhchenko
f839d501a7 Apply gofmt -s
And add CI check for that.
2023-01-24 07:55:00 -06:00
Alexey Palazhchenko
f581584148 Use Go 1.19's lists for proper formatting 2023-01-23 19:54:30 -06:00
Jack Christensen
e48e7a7189 Fix scanning json column into **string
refs https://github.com/jackc/pgx/issues/1470
2023-01-20 18:38:11 -06:00
Mark Chambers
516300aabf spelling: successfully, compatibility 2023-01-16 20:06:01 -06:00
Mark Chambers
62a7e19a04 func multiInsert returns nil when err != nil
I suspect it should return err.
2023-01-16 20:06:01 -06:00
Mark Chambers
672431c0bd Replace deprecated "io/ioutil"
ioutil.TempFile: Deprecated: As of Go 1.17, this function simply calls os.CreateTemp.

ioutil.ReadFile: Deprecated: As of Go 1.16, this function simply calls os.ReadFile.
2023-01-16 20:06:01 -06:00
Mark Chambers
7c0c7dc01e Remove unused test struct. 2023-01-16 20:06:01 -06:00
Jack Christensen
fcec008a4c Update CI to test on Go 1.19 2023-01-14 09:37:11 -06:00
Jack Christensen
d993cfa8fd Use puddle with Go 1.19 atomics instead of uber atomics
Doing this a bit early to resolve
https://github.com/jackc/pgx/issues/1465. Won't actually tag the release
until Go 1.20 is released to comply with pgx's versioning policy.
2023-01-14 09:31:38 -06:00
Jack Christensen
a95cfe5cc5 Fix connect with multiple hostnames when one can't be resolved
If multiple hostnames are provided and one cannot be resolved the others
should still be tried.

Longterm, it would be nice for the connect process to return a list of
errors rather than just one.

fixes https://github.com/jackc/pgx/issues/1464
2023-01-14 09:19:00 -06:00
Mark Chambers
c46d792c93 Numeric numberTextBytes() workaround...
This seems a bit of a hack. It fixes the problems demonstrated in my previous commit.

Maybe there's a cleaner way?

Associated: https://github.com/jackc/pgx/issues/1426
2023-01-14 08:42:42 -06:00
Mark Chambers
37c6f97b11 pgtype.Numeric numberTextBytes() encoding bug
Demonstrate the problem with the tests:

...for negative decimal values e.g. -0.01

This causes errors when encoding to JSON:

    "json: error calling MarshalJSON for type pgtype.Numeric"

It also causes scan failures of sql.NullFloat64:

    "converting driver.Value type string ("0.-1") to a float64"

As reported here: https://github.com/jackc/pgx/issues/1426
2023-01-14 08:42:42 -06:00
Alex Goncharov
74f9b9f0a4 Bump github.com/jackc/pgservicefile to v0.0.0-20221227161230-091c0ba34f0a to get rid of vulnerable version of gopkg.in/yaml.v2
Signed-off-by: Alex Goncharov <github@b4bay.com>
2022-12-27 17:31:07 -06:00
Stephen Afam-Osemene
5177e1a8df Add stephenafamo/scan reference to README.md 2022-12-27 10:13:36 -06:00
Jack Christensen
d4fcd4a897 Support sql.Scanner on renamed base type
https://github.com/jackc/pgtype/issues/197
2022-12-23 14:22:59 -06:00
Wagner Camarao
c514b2e0c3 add pmx module reference to the readme file 2022-12-23 13:51:59 -06:00
Jack Christensen
e66ad1bcec Fix encode to json ignoring driver.Valuer
https://github.com/jackc/pgx/issues/1430
2022-12-23 13:44:09 -06:00
Alejandro Do Nascimento Mora
c4ac6d810f Use DefaultQueryExecMode in CopyFrom
CopyFrom had to create a prepared statement to get the OIDs of the data
types that were going to be copied into the table. Every COPY operation
required an extra round trip to retrieve the type information. There
was no way to customize this behavior.

By leveraging the QueryExecMode feature, like in `Conn.Query`, users can
specify if they want to cache the prepared statements, execute
them on every request (like the old behavior), or bypass the prepared
statement relying on the pgtype.Map to get the type information.

The `QueryExecMode` behaves exactly like in `Conn.Query` in the way the
data type OIDs are fetched, meaning that:

- `QueryExecModeCacheStatement`: caches the statement.
- `QueryExecModeCacheDescribe`: caches the statement description and
  assumes it does not change.
- `QueryExecModeDescribeExec`: gets the statement description on every
  execution. This is like the old behavior of `CopyFrom`.
- `QueryExecModeExec` and `QueryExecModeSimpleProtocol`: maintain the
  same behavior as before, which is the same as `QueryExecModeDescribeExec`.
  They will keep getting the statement description on every execution.

The `QueryExecMode` can only be set via
`ConnConfig.DefaultQueryExecMode`, unlike `Conn.Query` there's no
support for specifying the `QueryExecMode` via optional arguments
in the function signature.
2022-12-23 13:22:26 -06:00
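A minimal sketch of selecting the old behavior via the config (the table and column names are illustrative):

```go
package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()

	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		log.Fatal(err)
	}
	// Fetch column OIDs with a statement description on every CopyFrom,
	// i.e. the pre-change behavior.
	config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec

	conn, err := pgx.ConnectConfig(ctx, config)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	n, err := conn.CopyFrom(ctx,
		pgx.Identifier{"people"},
		[]string{"first_name", "age"},
		pgx.CopyFromRows([][]any{{"Ada", 36}, {"Grace", 41}}),
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("copied %d rows", n)
}
```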
Jack Christensen
456a242f5c Unregistered OIDs are handled the same as unknown OIDs
This improves handling of unregistered types. In general, they should
"just work". But there are performance benefits gained and some edge
cases avoided by registering types. Updated documentation to mention
this.

https://github.com/jackc/pgx/issues/1445
2022-12-23 13:14:56 -06:00
Jack Christensen
d737852654 Fix: driver.Value representation of bytea should be []byte not string
https://github.com/jackc/pgx/issues/1445
2022-12-21 17:54:42 -06:00
Ben Weintraub
29ad306e47 Make MaxConnLifetimeJitter setting actually jitter 2022-12-20 20:18:26 -06:00
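A minimal sketch of the setting (the lifetime and jitter values are illustrative):

```go
package example

import (
	"context"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
)

// newPoolWithJitter spreads connection expiration: a random duration up to
// MaxConnLifetimeJitter is added to MaxConnLifetime so connections created
// in a burst do not all expire at the same instant.
func newPoolWithJitter(ctx context.Context, connString string) (*pgxpool.Pool, error) {
	config, err := pgxpool.ParseConfig(connString)
	if err != nil {
		return nil, err
	}
	config.MaxConnLifetime = time.Hour
	config.MaxConnLifetimeJitter = 10 * time.Minute
	return pgxpool.NewWithConfig(ctx, config)
}
```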
Jack Christensen
f42af35884 Add support for single dimensional arrays
https://github.com/jackc/pgx/issues/1442
2022-12-20 20:12:12 -06:00
Yevgeny Pats
11fa083a0d fix: Improve errors in batch modes 2022-12-20 19:33:46 -06:00
Mark Chambers
1ce3e0384a pgtype Int: fix minimum error message.
Previously, on the minimum condition the error would be:

  "is greater than maximum"

Also add encoding/json import into the .erb template as the import was
missing after running rake generate.
2022-12-17 09:10:02 -06:00
Alejandro Do Nascimento Mora
e58381ac94 Enable some CopyFrom tests for cockroachDB
CockroachDB added support for COPY in version 20.2.

https://www.cockroachlabs.com/docs/v20.2/copy-from

There are some limitations on the implementation, that's why not all the
existing tests were enabled.
2022-12-12 18:22:32 -06:00
Jack Christensen
279c3c0a20 Fix: json values work with sql.Scanner
https://github.com/jackc/pgx/issues/1418
2022-12-06 19:44:55 -06:00
Jack Christensen
17f8f7af63 Release v5.2.0 2022-12-05 20:41:55 -06:00
Jack Christensen
f0a73424b1 Fix: Scan uint and uint64 ScanNumeric
fixes https://github.com/jackc/pgx/issues/1414
2022-12-05 20:34:46 -06:00
Vitalii Solodilov
88b373f9ee Skipped multirange tests for postgres less than 14 version 2022-12-01 19:33:33 -06:00
Vitalii Solodilov
8e2de2fefa Conn.LoadType supports range and multirange types (#1393)
Closes #1393
2022-12-01 19:33:33 -06:00
Nazar Vovk
24c53259f8 Fix typo 2022-11-28 09:36:20 -06:00
ksco
8eb062f588 perf(tx): use strings.Builder to avoid the overhead of []byte -> string conversion 2022-11-25 12:39:22 -06:00
Petr Evdokimov
fbfafb3edf Optimize 'beginSQL' runtime and memory allocations 2022-11-22 09:00:12 -06:00
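These two commits apply the same technique; a generic sketch (not pgx's exact code) of why strings.Builder helps:

```go
package example

import "strings"

// buildBeginSQL sketches the optimization: strings.Builder hands its backing
// buffer to String() without the copy that a []byte -> string conversion
// would require.
func buildBeginSQL(isoLevel string) string {
	var sb strings.Builder
	sb.WriteString("begin")
	if isoLevel != "" {
		sb.WriteString(" isolation level ")
		sb.WriteString(isoLevel)
	}
	return sb.String()
}
```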
Vitalii Solodilov
174224fa07 The tracelog.TraceLog implements the pgx.PrepareTracer interface
Allows TraceLog to log Prepare queries.

Fixes #1383

Unit tests:
* added logger.Clear method to clean up old log messages
* added logger.FilterByMsg to get only specific logs for assertions. When queries are executed using different query exec modes, the prepare query may not be executed, so the number of logs can differ between exec modes.
2022-11-19 07:43:39 -06:00
Jack Christensen
8ad1394f4c Update changelog for v5.1.1 2022-11-17 19:47:09 -06:00
Bodo Kaiser
56633b3d51 removed unnecessary name argument from DeallocateAll 2022-11-17 19:41:18 -06:00
Jack Christensen
ba4bbf92af Fix query sanitizer
...when query text contains the Unicode replacement character.
utf8.RuneError is actually a valid character.
2022-11-14 18:32:26 -06:00
Jack Christensen
b4d2eae777 Update changelog 2022-11-12 11:02:55 -06:00
Bodo Kaiser
3520c2ea43 updated DeallocateAll to also reset client-side statement and description cache 2022-11-12 10:57:31 -06:00
Bodo Kaiser
c94c47f584 added DeallocateAll to pgx.Conn to clear prepared statement cache 2022-11-12 10:57:31 -06:00
Jack Christensen
8678ed560f Update puddle to v2.1.2 2022-11-12 10:42:08 -06:00
Jack Christensen
05924a9d6b Update CONTRIBUTING.md 2022-11-12 10:42:02 -06:00
Jack Christensen
2e9e2865f9 Added more docs and tests 2022-11-12 10:13:20 -06:00
Pavlo Golub
14be51536b implement RowToStructByName and RowToAddrOfStructByName 2022-11-12 09:39:54 -06:00
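A minimal usage sketch (the users table is illustrative):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// User's fields are matched to the returned columns by name rather than
// by position.
type User struct {
	ID   int32
	Name string
}

func loadUsers(ctx context.Context, conn *pgx.Conn) ([]User, error) {
	rows, _ := conn.Query(ctx, "select id, name from users")
	// CollectRows closes rows and surfaces any query error.
	return pgx.CollectRows(rows, pgx.RowToStructByName[User])
}
```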
Jack Christensen
1376a2c0ed Update Go doc badge 2022-11-12 09:23:07 -06:00
Jack Christensen
932f676cfd Remove PG 10 from CI and add PG 15 to CI
PG 10 is now out of support.
2022-11-12 09:20:48 -06:00
Jack Christensen
5b6fb75669 Conn.LoadType supports domain types
If the underlying type is registered then use the same Codec.

fixes https://github.com/jackc/pgx/issues/1373
2022-11-12 08:11:37 -06:00
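A minimal sketch, using the uint64 domain from this repository's test setup (create domain uint64 as numeric(20,0)):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// registerDomain loads a domain type by name and registers it so pgx encodes
// and decodes it with the underlying type's Codec.
func registerDomain(ctx context.Context, conn *pgx.Conn) error {
	t, err := conn.LoadType(ctx, "uint64")
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterType(t)
	return nil
}
```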
Jack Christensen
b265fedd75 Correct error message 2022-11-12 07:06:54 -06:00
Jack Christensen
871f14e43b Fix text decoding of dates with 5 digit years 2022-11-12 07:01:11 -06:00
Jack Christensen
071d1c9467 DateCodec.DecodeValue can return pgtype.InfinityModifier
Previously, an infinite value was returned as a string. Other types
that can be infinite such as Timestamptz return a
pgtype.InfinityModifier. This change brings them into alignment.
2022-11-12 06:27:41 -06:00
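A minimal sketch of the new behavior when scanning into `any`:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// scanInfiniteDate returns pgtype.Infinity (an InfinityModifier) for an
// infinite date, where previously the string "infinity" was returned.
func scanInfiniteDate(ctx context.Context, conn *pgx.Conn) (any, error) {
	var v any
	err := conn.QueryRow(ctx, "select 'infinity'::date").Scan(&v)
	return v, err
}
```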
Jack Christensen
29109487ec DateCodec.DecodeDatabaseSQLValue returns time.Time when possible
Previously it returned a string. However, this was an unintended
behavior change from pgx v4.

89f69aaea9 (commitcomment-89173737)
2022-11-12 06:21:48 -06:00
Jack Christensen
daf570c752 Date text encoding pads year with 0 for at least 4 digits
e.g. 0007-01-02 instead of 7-01-02

89f69aaea9 (commitcomment-89173737)
2022-11-12 06:14:04 -06:00
Jack Christensen
a86acf61e0 Fix encode ErrorResponse
fixes https://github.com/jackc/pgx/issues/1371
2022-11-11 18:20:16 -06:00
Jack Christensen
a968ce3437 Add typed nil behavior change note to changelog
https://github.com/jackc/pgx/issues/1367
2022-11-03 21:24:44 -05:00
Jack Christensen
39676004de Fix logger string truncation with UTF-8
fixes #1365
2022-11-03 20:50:30 -05:00
Jack Christensen
6f90866f58 Expose underlying pgconn GetSSLPassword support to pgx
pgconn supports a GetSSLPassword function but the pgx connection
functions did not expose a means of using it.

See PR #1233 for more context.
2022-11-03 20:09:52 -05:00
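A minimal sketch; the connection string, certificate file names, and environment variable are illustrative:

```go
package example

import (
	"context"
	"os"

	"github.com/jackc/pgx/v5"
)

// parseConfigWithSSLPassword supplies the passphrase for an encrypted sslkey
// file. GetSSLPassword is defined on the embedded pgconn.Config.
func parseConfigWithSSLPassword() (*pgx.ConnConfig, error) {
	config, err := pgx.ParseConfig(
		"host=localhost user=pgx_sslcert sslmode=verify-full sslcert=client.crt sslkey=client.key")
	if err != nil {
		return nil, err
	}
	config.GetSSLPassword = func(ctx context.Context) string {
		return os.Getenv("PGX_SSL_PASSWORD")
	}
	return config, nil
}
```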
Jack Christensen
d8c04249d1 Give up on that test in CI
The test works if I use upterm and run manually on the CI server...
TLS is tested in the TLS with client certificate tests anyway.
2022-10-31 22:37:05 -05:00
Jack Christensen
7fd064ab80 Disable upterm 2022-10-31 22:28:50 -05:00
Jack Christensen
0013f6c7ca Enable upterm 2022-10-31 22:20:58 -05:00
Jack Christensen
95498282bb more ci 2022-10-31 22:10:37 -05:00
Jack Christensen
6e77e0a09d Fight with CI some more 2022-10-31 22:05:35 -05:00
Jack Christensen
1f0fd66623 Go back to Ubuntu 20.04 on CI
Should fix some strange openssl / TLS issues.
2022-10-31 21:57:38 -05:00
Jack Christensen
45aeaed20a Remove unused pg-version matrix 2022-10-31 21:28:58 -05:00
Jack Christensen
a2da398dff Partial CI fixes 2022-10-31 21:24:57 -05:00
Jack Christensen
be419e25b4 Use des3 for certs in testing / CI 2022-10-31 19:42:22 -05:00
Jack Christensen
dd07e24a6c sudo the CI 2022-10-31 19:34:59 -05:00
Jack Christensen
0920c79b02 Test SCRAM, sslmode=verify-full and client cert auth on CI 2022-10-31 19:30:22 -05:00
Jack Christensen
268af3903c Upgrade CI to ubuntu-22.04 2022-10-31 19:10:49 -05:00
Jack Christensen
4d711aaa73 Remove v5-dev branch from CI 2022-10-31 19:10:21 -05:00
Jack Christensen
dc85718658 Remove unused code from CI script 2022-10-29 19:02:04 -05:00
Jack Christensen
6b52e0b5e0 Contributing guide now includes instructions to test client ssl auth 2022-10-29 19:00:29 -05:00
Jack Christensen
9eaeb51e30 Fix CI PostgreSQL user permissions 2022-10-29 17:55:13 -05:00
Jack Christensen
8b2ac8c18f Fix unix domain socket tests on CI 2022-10-29 17:45:13 -05:00
Jack Christensen
05e9234c2e Upgrade setup-go and checkout actions to v3 2022-10-29 17:29:10 -05:00
Jack Christensen
97d1012f42 Use testsetup/postgresql_setup.sql in CI 2022-10-29 17:27:39 -05:00
Jack Christensen
6bedfa7def Use testsetup/pg_hba.conf in CI 2022-10-29 17:23:13 -05:00
Jack Christensen
55b5067ddd Improve testing / contributing instructions
* Extract CONTRIBUTING.md
* Add instructions and scripts to setup standalone PostgreSQL server
  that tests the various connection and authentication types.
2022-10-29 17:14:09 -05:00
Jack Christensen
1ec3816a20 pgconn and pgproto use same environment variable for tests as pgx 2022-10-29 13:23:25 -05:00
Jack Christensen
c9c166b8b2 Fix TestConnCopyFromDataWriteAfterErrorAndReturn always being skipped 2022-10-29 13:17:52 -05:00
Jack Christensen
9a207178f6 Fix TestConnCheckConn always being skipped 2022-10-29 13:16:05 -05:00
Jack Christensen
3feeddd9f1 Fix tests when PGUSER is different than OS user 2022-10-29 13:12:03 -05:00
Jack Christensen
72c89108ad Fix tests when PGPORT set to non-default value 2022-10-29 13:06:53 -05:00
Jack Christensen
c130b2d74a Update CopyFrom documentation to be clearer
Regarding binary requirement and enums in particular.

https://github.com/jackc/pgx/issues/1338
2022-10-29 09:48:45 -05:00
Jack Christensen
7d3b9c1e44 QueryRewriter.RewriteQuery now returns an error
https://github.com/jackc/pgx/issues/1186#issuecomment-1288207250
2022-10-29 09:33:13 -05:00
Jack Christensen
6515e183ff Update doc example for pgx.ForEachRow
fixes https://github.com/jackc/pgx/issues/1360
2022-10-29 08:59:57 -05:00
Jack Christensen
e35041372d Remove mistakenly included replace directive in go.mod 2022-10-29 08:56:49 -05:00
Jack Christensen
6fabd8f5b1 Fix encoding uint64 larger than math.MaxInt64 into numeric
fixes https://github.com/jackc/pgx/issues/1357
2022-10-29 08:47:12 -05:00
Jack Christensen
c00fb5d2a1 Upgrade to puddle v2.0.1 2022-10-29 08:09:54 -05:00
Jack Hopner
55d5d036c0 add pgx xray tracer to readme 2022-10-27 19:42:36 -05:00
Jack Christensen
987de3874e Update changelog 2022-10-24 19:11:50 -05:00
Jack Christensen
3ad9995dfe Exec checks if tx is closed
https://github.com/jackc/pgx/discussions/1350
2022-10-24 18:23:26 -05:00
Baptiste Fontaine
3e825ec898 Fix RowToStructByPos on structs with multiple anonymous sub-structs
Fixes #1343
2022-10-22 10:02:32 -05:00
Jeff Koenig
ba100785cc fix: bump text package for CVE-2022-32149
https://security.snyk.io/vuln/SNYK-GOLANG-GOLANGORGXTEXTLANGUAGE-3043869
2022-10-22 09:07:24 -05:00
Jack Christensen
48b4807b33 Fix some reflect Kind checks to first check for nil
fixes https://github.com/jackc/pgx/issues/1335
2022-10-22 08:57:49 -05:00
Jack Christensen
6e40968cfc CollectOneRow prefers PostgreSQL error over pgx.ErrNoRows
fixes https://github.com/jackc/pgx/issues/1334
2022-10-22 08:44:06 -05:00
Jack Christensen
11e5f68ff6 Update changelog for v5.0.3 2022-10-14 19:11:11 -05:00
Baptiste Fontaine
7a9e70d1e0 Fix some bad rows.Err() handlings in tests 2022-10-14 19:02:44 -05:00
Jack Christensen
f2e7c8144d reflect.TypeOf can return nil. Check before using
https://github.com/jackc/pgx/issues/1331
2022-10-12 20:03:51 -05:00
Jack Christensen
aff180b192 Remove dead code 2022-10-12 19:58:06 -05:00
Jack Christensen
a581124dea Encode with driver.Valuer after trying TryWrapEncodePlanFuncs
However, all builtin TryWrapEncodePlanFuncs check for driver.Valuer and
skip themselves if it is found.
2022-10-12 19:52:57 -05:00
Jack Christensen
c4407fb36e Prevent infinite loop for driver.Valuer / Codec edge case
A `driver.Valuer()` results in a `string` that the `Codec` for the
PostgreSQL type doesn't know how to handle. That string is scanned into
whatever the default type for that `Codec` is. That new value is
encoded. If the new value is the same type as the original type, then an
infinite loop occurred. Check that the types are different.

https://github.com/jackc/pgx/issues/1331
2022-10-12 19:46:15 -05:00
Jack Christensen
094ad9c9d8 Update changelog for v5.0.2 2022-10-08 18:58:17 -05:00
Jack Christensen
af0b896290 Allow scanning null even if PG and Go types are incompatible
refs https://github.com/jackc/pgx/issues/1326
2022-10-08 09:10:43 -05:00
Jack Christensen
5655f9d593 Fix scan to pointer to pointer to renamed type
refs https://github.com/jackc/pgx/issues/1326
2022-10-08 08:10:40 -05:00
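A minimal sketch of the fixed case (the renamed type and query are illustrative):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// Name is a renamed base type. Scanning into **Name now works, with a SQL
// NULL mapping to a nil *Name.
type Name string

func scanName(ctx context.Context, conn *pgx.Conn) (*Name, error) {
	var n *Name
	err := conn.QueryRow(ctx, "select 'Ada'::text").Scan(&n)
	return n, err
}
```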
Jack Christensen
f803c790d0 Fix docs for listen / notify
https://github.com/jackc/pgx/issues/1318
2022-10-01 12:58:49 -05:00
Jack Christensen
222e3b37bc Prefer driver.Valuer over wrap plans when encoding
This is tricky due to driver.Valuer returning any. For example, we can
plan for fmt.Stringer because it always returns a string.

Because of this driver.Valuer was always handled as the last option. But
with pgx v5 now having the ability to find underlying types like a
string and supporting fmt.Stringer it meant that driver.Valuer was
often not getting called because something else was found first.

This change tries driver.Valuer immediately after the initial PlanScan
for the Codec. So a type that directly implements a pgx interface should
be used, but driver.Valuer will be preferred before all the attempts to
handle renamed types, pointer dereferencing, etc.

fixes https://github.com/jackc/pgx/issues/1319
fixes https://github.com/jackc/pgx/issues/1311
2022-10-01 12:20:23 -05:00
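A sketch of the kind of type this affects (the type itself is illustrative):

```go
package example

import "database/sql/driver"

// Status is a renamed string that also implements driver.Valuer. Previously
// pgx's renamed-type handling could win and Value() was never called; now the
// Valuer is tried immediately after the Codec's own plan.
type Status string

func (s Status) Value() (driver.Value, error) {
	return string(s), nil
}
```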
Jack Christensen
89f69aaea9 Date text encoding includes leading zero for month and day
e.g. 2000-01-01 instead of 2000-1-1. PostgreSQL accepted it without
zeroes but our text decoder didn't. This caused a problem when we needed
to take a value and encode to text so something else could parse it as
if it had come from the PostgreSQL server in text format. e.g.
database/sql compatibility.
2022-10-01 10:41:40 -05:00
Jack Christensen
63ae730fe8 Upgrade CockroachDB on CI 2022-10-01 10:11:11 -05:00
Jack Christensen
305c4ddbc7 Move and rename test 2022-10-01 10:09:57 -05:00
Jack Christensen
fb83fb0cc3 Skip TestCopyFrom on CockroachDB 2022-10-01 10:08:03 -05:00
Tommy Reilly
c48dd7e1f8 Add a test case demonstrating I/O race with CopyFrom 2022-10-01 10:07:38 -05:00
Jack Christensen
cd8b29b0fe Fix flickering on TestConnectTimeoutStuckOnTLSHandshake
Ensure that even if the outer function finishes the goroutine can still
send an error.
2022-09-24 12:54:59 -05:00
Jack Christensen
0aa681f3a3 Update changelog for v5.0.1 2022-09-24 11:15:31 -05:00
Jack Christensen
335c8621ff Fix sqlScannerWrapper NULL handling
https://github.com/jackc/pgx/issues/1312
2022-09-24 10:30:12 -05:00
Jack Christensen
ac9d4f4d96 Encode text for Lseg includes [ and ]
https://github.com/jackc/pgtype/issues/187
2022-09-24 10:30:12 -05:00
yogipristiawan
72e4b88e56 feat: add marshalJSON for float8 type 2022-09-24 10:00:40 -05:00
Peter Feichtinger
639fb28846 Fix typo 2022-09-24 09:26:52 -05:00
Jack Christensen
d7c7ddc594 Fix Windows 386 atomic usage
https://github.com/jackc/pgx/issues/1307
2022-09-24 09:23:36 -05:00
Jack Christensen
4fc4f9a603 Remove spurious .travis.yml 2022-09-17 10:36:36 -05:00
Jack Christensen
23a59d68fc Merge branch 'v5-dev' 2022-09-17 10:35:32 -05:00
Jack Christensen
5a055434f2 Upgrade dependencies 2022-09-17 10:24:19 -05:00
Jack Christensen
1a314bda3b pgconn.Timeout() no longer considers context.Canceled as a timeout error.
https://github.com/jackc/pgconn/issues/81
2022-09-17 10:18:06 -05:00
Jack Christensen
4f1a8084f1 Various doc and changelog tweaks 2022-09-17 09:03:48 -05:00
Jack Christensen
a05fb80b8a Update docs and changelog for renamed pgxpool.NewWithConfig
fixes https://github.com/jackc/pgx/issues/1306
2022-09-16 18:16:36 -05:00
Jack Christensen
90b69c0ee0 Fix atomic alignment on 32-bit platforms
refs #1288
2022-09-08 20:43:53 -05:00
Jack Christensen
ee2622a8e6 RowToStructByPos supports embedded structs
https://github.com/jackc/pgx/issues/1273#issuecomment-1236966785
2022-09-06 18:32:10 -05:00
Jack Christensen
d42b399be3 Update changelog 2022-09-03 13:42:36 -05:00
Jack Christensen
f015ced1bf Use puddle v2.0.0-beta.2 for Acquire in background after cancel 2022-09-03 13:20:19 -05:00
Jack Christensen
782133158f Test sending CopyData before CopyFrom responds with error 2022-09-03 09:31:41 -05:00
Tom Möller
dfce986bb5 Fix panic when logging batch error 2022-09-03 09:02:23 -05:00
Jack Christensen
f8d088cfb6 Fix JSON scan not completely overwriting destination
See https://github.com/jackc/pgtype/pull/185 for original report in
pgx v4 / pgtype.
2022-09-02 18:37:02 -05:00
Jack Christensen
f5cdf0d383 Update changelog 2022-08-27 18:18:41 -05:00
Jack Christensen
72fe594942 Upgrade to puddle v1.3.0 2022-08-27 18:18:34 -05:00
Jack Christensen
bce26b85d1 Fix atomic alignment on 32-bit platforms
refs #1288
2022-08-27 09:23:17 -05:00
Jack Christensen
bb6c997102 Add NewCommandTag
Useful for mocking and testing.

https://github.com/jackc/pgx/issues/1273#issuecomment-1224154013
2022-08-23 19:39:15 -05:00
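A minimal sketch of the mocking use case:

```go
package example

import "github.com/jackc/pgx/v5/pgconn"

// fakeInsertTag builds a CommandTag for a test double without a server
// round trip.
func fakeInsertTag() pgconn.CommandTag {
	tag := pgconn.NewCommandTag("INSERT 0 1")
	_ = tag.Insert()       // true
	_ = tag.RowsAffected() // 1
	return tag
}
```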
Jack Christensen
fe3a4f3150 Standardize casing for NULL in error messages 2022-08-22 21:01:18 -05:00
Jack Christensen
2e73d1e8ee Improve error message when failing to scan a NULL::json 2022-08-22 20:56:36 -05:00
Jack Christensen
0d5d8e0137 Fallback to other format when encoding query arguments
The preferred format may not be possible for certain arguments. For
example, the preferred format for numeric is binary. But if
shopspring/decimal is being used without jackc/pgx-shopspring-decimal
then it will use the database/sql/driver.Valuer interface. This will
return a string. That string should be sent in the text format.

A similar case occurs when encoding a []string into a non-text
PostgreSQL array such as uuid[].
2022-08-22 20:26:38 -05:00
Jack Christensen
ae65a8007b Use higher-level pgconn.FieldDescription with string Name
Instead of exposing pgproto3.FieldDescription through pgconn and pgx. This
lets the lowest-level pgproto3 package stay as memory efficient as possible.

https://github.com/jackc/pgx/pull/1281
2022-08-20 10:04:18 -05:00
Jack Christensen
dbee461dc9 Update previous pgconn merge for v5 2022-08-19 17:42:04 -05:00
Jack Christensen
ef5655c563 Merge remote-tracking branch 'pgconn/master' into v5-dev 2022-08-19 17:36:29 -05:00
Stas Kelvich
15f8e6323e Fix tests that check tls.Config.ServerName -- with SNI this field
is filled, unless SNI is deliberately disabled. Also, do not set
SNI when host is an IP address as per RFC 6066.
2022-08-19 17:35:33 -05:00
Stas Kelvich
e3406d95f9 Add test coverage for client SNI 2022-08-19 17:35:33 -05:00
Stas Kelvich
067771b2e6 Set SNI for SSL connections
This allows an SNI-aware proxy to route connections. Patch adds a new
connection option (`sslsni`) to opt out of the SNI, to have the same
behavior as `libpq` does. See more in `sslsni` sections at
<https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS>.
2022-08-19 17:35:33 -05:00
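A minimal sketch of opting out, assuming libpq's 0/1 syntax for the option (the host is illustrative):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// connectWithoutSNI disables SNI via the sslsni parameter, matching libpq.
func connectWithoutSNI(ctx context.Context) (*pgx.Conn, error) {
	return pgx.Connect(ctx, "host=db.example.com sslmode=require sslsni=0")
}
```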
Jack Christensen
8eae4a2a3e Merge remote-tracking branch 'pgconn/master' into v5-dev 2022-08-13 10:19:49 -05:00
Nathan Giardina
faabb0696f Fix for timeout when a single node has timed out: create a new context to allow each db node to time out individually 2022-08-13 10:18:55 -05:00
Jack Christensen
1d748d9bbf Failsafe timeout for background pool connections
Do not override existing connect timeout.
2022-08-13 09:50:37 -05:00
Jack Christensen
c842802d65 Failsafe timeout for background pool connections
Do not override existing connect timeout.
2022-08-13 09:49:06 -05:00
Jack Christensen
7c6a31f9d2 CopyFrom parses strings to encode into binary format
https://github.com/jackc/pgx/issues/1277
https://github.com/jackc/pgx/issues/1267
2022-08-13 09:30:29 -05:00
Jack Christensen
02d9a5acd8 Fix naming of some tests 2022-08-13 08:41:06 -05:00
Jack Christensen
8256ab147f Add build tag to skip default PG type registration
https://github.com/jackc/pgx/issues/1273#issuecomment-1207338136
2022-08-13 08:09:44 -05:00
Jack Christensen
906f709e0c Fix typo in Windows code
https://github.com/jackc/pgx/issues/1274
2022-08-11 20:59:37 -05:00
Jack Christensen
33b782a96d Potential fix for Windows
https://github.com/jackc/pgx/issues/1274
2022-08-11 20:55:50 -05:00
Jack Christensen
1453cd4b97 Update v5 status 2022-08-06 07:11:11 -05:00
Jack Christensen
6871a0c4a6 Add v5 testing note to readme 2022-08-06 07:10:37 -05:00
234 changed files with 15651 additions and 5476 deletions


@@ -23,7 +23,7 @@ import (
"log" "log"
"os" "os"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v5"
) )
func main() { func main() {
@@ -37,6 +37,8 @@ func main() {
} }
``` ```
Please run your example with the race detector enabled. For example, `go run -race main.go` or `go test -race`.
**Expected behavior** **Expected behavior**
A clear and concise description of what you expected to happen. A clear and concise description of what you expected to happen.


@@ -2,87 +2,155 @@ name: CI
on: on:
push: push:
branches: [ master, v5-dev ] branches: [master]
pull_request: pull_request:
branches: [master] branches: [master]
jobs: jobs:
test: test:
name: Test name: Test
runs-on: ubuntu-20.04 runs-on: ubuntu-22.04
strategy: strategy:
matrix: matrix:
go-version: [1.18] go-version: ["1.23", "1.24"]
pg-version: [10, 11, 12, 13, 14, cockroachdb] pg-version: [13, 14, 15, 16, 17, cockroachdb]
include: include:
- pg-version: 10
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 11
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 12
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test
- pg-version: 13 - pg-version: 13
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 14 - pg-version: 14
pgx-test-database: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test" pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-tls-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: postgres://pgx_md5:secret@127.0.0.1/pgx_test pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: postgres://pgx_pw:secret@127.0.0.1/pgx_test pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 15
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 16
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: 17
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
pgx-ssl-password: certpw
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
- pg-version: cockroachdb - pg-version: cockroachdb
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on" pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
pgx-test-conn-string: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
steps: steps:
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Set up Go 1.x - name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v2 uses: actions/setup-go@v5
with: with:
go-version: ${{ matrix.go-version }} go-version: ${{ matrix.go-version }}
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Setup database server for testing - name: Setup database server for testing
run: ci/setup_test.bash run: ci/setup_test.bash
env: env:
PGVERSION: ${{ matrix.pg-version }} PGVERSION: ${{ matrix.pg-version }}
# - name: Setup upterm session
# uses: lhotari/action-upterm@v1
# with:
# ## limits ssh access and adds the ssh public key for the user which triggered the workflow
# limit-access-to-actor: true
# env:
# PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
# PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
# PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
# PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
# PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
# PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
# PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
# PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
- name: Check formatting
run: |
gofmt -l -s -w .
git status
git diff --exit-code
- name: Test - name: Test
run: go test -race ./... # parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
run: go test -parallel=1 -race ./...
env: env:
PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }} PGX_TEST_DATABASE: ${{ matrix.pgx-test-database }}
PGX_TEST_CONN_STRING: ${{ matrix.pgx-test-conn-string }}
PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }} PGX_TEST_UNIX_SOCKET_CONN_STRING: ${{ matrix.pgx-test-unix-socket-conn-string }}
PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }} PGX_TEST_TCP_CONN_STRING: ${{ matrix.pgx-test-tcp-conn-string }}
PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }} PGX_TEST_SCRAM_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-scram-password-conn-string }}
PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }} PGX_TEST_MD5_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-md5-password-conn-string }}
PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }} PGX_TEST_PLAIN_PASSWORD_CONN_STRING: ${{ matrix.pgx-test-plain-password-conn-string }}
# TestConnectTLS fails. However, it succeeds if I connect to the CI server with upterm and run it. Give up on that test for now.
# PGX_TEST_TLS_CONN_STRING: ${{ matrix.pgx-test-tls-conn-string }}
PGX_SSL_PASSWORD: ${{ matrix.pgx-ssl-password }}
PGX_TEST_TLS_CLIENT_CONN_STRING: ${{ matrix.pgx-test-tls-client-conn-string }}
test-windows:
name: Test Windows
runs-on: windows-latest
strategy:
matrix:
go-version: ["1.23", "1.24"]
steps:
- name: Setup PostgreSQL
id: postgres
uses: ikalnytskyi/action-setup-postgres@v4
with:
database: pgx_test
- name: Check out code into the Go module directory
uses: actions/checkout@v4
- name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go-version }}
- name: Initialize test database
run: |
psql -f testsetup/postgresql_setup.sql pgx_test
env:
PGSERVICE: ${{ steps.postgres.outputs.service-name }}
shell: bash
- name: Test
# parallel testing is disabled because somehow parallel testing causes Github Actions to kill the runner.
run: go test -parallel=1 -race ./...
env:
PGX_TEST_DATABASE: ${{ steps.postgres.outputs.connection-uri }}

.gitignore vendored (3 changes)

@@ -22,3 +22,6 @@ _testmain.go
*.exe *.exe
.envrc .envrc
/.testdb
.DS_Store


@@ -1,9 +0,0 @@
language: go
go:
- 1.x
- tip
matrix:
allow_failures:
- go: tip


@@ -1,4 +1,286 @@ # Unreleased v5 # 5.7.5 (May 17, 2025)
# Unreleased v5 # 5.7.5 (May 17, 2025)
* Support sslnegotiation connection option (divyam234)
* Update golang.org/x/crypto to v0.37.0. This placates security scanners that were unable to see that pgx did not use the behavior affected by https://pkg.go.dev/vuln/GO-2025-3487.
* TraceLog now logs Acquire and Release at the debug level (dave sinclair)
* Add support for PGTZ environment variable
* Add support for PGOPTIONS environment variable
* Unpin memory used by Rows quicker
* Remove PlanScan memoization. This resolves a rare issue where scanning could be broken for one type by first scanning another. The problem was in the memoization system and benchmarking revealed that memoization was not providing any meaningful benefit.
# 5.7.4 (March 24, 2025)
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
# 5.7.3 (March 21, 2025)
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
* Improve SQL sanitizer performance (ninedraft)
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
* Fix Values() for xml type always returning nil instead of []byte
* Add ability to send Flush message in pipeline mode (zenkovev)
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
* Better error messages when scanning structs (logicbomb)
* Fix handling of error on batch write (bonnefoa)
* Match libpq's connection fallback behavior more closely (felix-roehrich)
* Add MinIdleConns to pgxpool (djahandarie)
# 5.7.2 (December 21, 2024)
* Fix prepared statement already exists on batch prepare failure
* Add commit query to tx options (Lucas Hild)
* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
* Add message body size limits in frontend and backend (zene)
* Add xid8 type
* Ensure planning encodes and scans cannot infinitely recurse
* Implement pgtype.UUID.String() (Konstantin Grachev)
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
* Update golang.org/x/crypto
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
# 5.7.1 (September 10, 2024)
* Fix data race in tracelog.TraceLog
* Update puddle to v2.2.2. This removes the import of nanotime via linkname.
* Update golang.org/x/crypto and golang.org/x/text
# 5.7.0 (September 7, 2024)
* Add support for sslrootcert=system (Yann Soubeyrand)
* Add LoadTypes to load multiple types in a single SQL query (Nick Farrell)
* Add XMLCodec supports encoding + scanning XML column type like json (nickcruess-soda)
* Add MultiTrace (Stepan Rabotkin)
* Add TraceLogConfig with customizable TimeKey (stringintech)
* pgx.ErrNoRows wraps sql.ErrNoRows to aid in database/sql compatibility with native pgx functions (merlin)
* Support scanning binary formatted uint32 into string / TextScanner (jennifersp)
* Fix interval encoding to allow 0s and avoid extra spaces (Carlos Pérez-Aradros Herce)
* Update pgservicefile - fixes panic when parsing invalid file
* Better error message when reading past end of batch
* Don't print url when url.Parse returns an error (Kevin Biju)
* Fix snake case name normalization collision in RowToStructByName with db tag (nolandseigler)
* Fix: Scan and encode types with underlying types of arrays
# 5.6.0 (May 25, 2024)
* Add StrictNamedArgs (Tomas Zahradnicek)
* Add support for macaddr8 type (Carlos Pérez-Aradros Herce)
* Add SeverityUnlocalized field to PgError / Notice
* Performance optimization of RowToStructByPos/Name (Zach Olstein)
* Allow customizing context canceled behavior for pgconn
* Add ScanLocation to pgtype.Timestamp[tz]Codec
* Add custom data to pgconn.PgConn
* Fix ResultReader.Read() to handle nil values
* Do not encode interval microseconds when they are 0 (Carlos Pérez-Aradros Herce)
* pgconn.SafeToRetry checks for wrapped errors (tjasko)
* Failed connection attempts include all errors
* Optimize LargeObject.Read (Mitar)
* Add tracing for connection acquire and release from pool (ngavinsir)
* Fix encode driver.Valuer not called when nil
* Add support for custom JSON marshal and unmarshal (Mitar)
* Use Go default keepalive for TCP connections (Hans-Joachim Kliemeck)
# 5.5.5 (March 9, 2024)
Use spaces instead of parentheses for SQL sanitization.
This still solves the problem of negative numbers creating a line comment, but this avoids breaking edge cases such as
`set foo to $1` where the substitution is taking place in a location where an arbitrary expression is not allowed.
# 5.5.4 (March 4, 2024)
Fix CVE-2024-27304
SQL injection can occur if an attacker can cause a single query or bind message to exceed 4 GB in size. An integer
overflow in the calculated message size can cause the one large message to be sent as multiple messages under the
attacker's control.
Thanks to Paul Gerste for reporting this issue.
* Fix behavior of CollectRows to return empty slice if Rows are empty (Felix)
* Fix simple protocol encoding of json.RawMessage
* Fix *Pipeline.getResults should close pipeline on error
* Fix panic in TryFindUnderlyingTypeScanPlan (David Kurman)
* Fix deallocation of invalidated cached statements in a transaction
* Handle invalid sslkey file
* Fix scan float4 into sql.Scanner
* Fix pgtype.Bits not making copy of data from read buffer. This would cause the data to be corrupted by future reads.
# 5.5.3 (February 3, 2024)
* Fix: prepared statement already exists
* Improve CopyFrom auto-conversion of text-ish values
* Add ltree type support (Florent Viel)
* Make some properties of Batch and QueuedQuery public (Pavlo Golub)
* Add AppendRows function (Edoardo Spadolini)
* Optimize convert UUID [16]byte to string (Kirill Malikov)
* Fix: LargeObject Read and Write of more than ~1GB at a time (Mitar)
# 5.5.2 (January 13, 2024)
* Allow NamedArgs to start with underscore
* pgproto3: Maximum message body length support (jeremy.spriet)
* Upgrade golang.org/x/crypto to v0.17.0
* Add snake_case support to RowToStructByName (Tikhon Fedulov)
* Fix: update description cache after exec prepare (James Hartig)
* Fix: pipeline checks if it is closed (James Hartig and Ryan Fowler)
* Fix: normalize timeout / context errors during TLS startup (Samuel Stauffer)
* Add OnPgError for easier centralized error handling (James Hartig)
# 5.5.1 (December 9, 2023)
* Add CopyFromFunc helper function. (robford)
* Add PgConn.Deallocate method that uses PostgreSQL protocol Close message.
* pgx uses new PgConn.Deallocate method. This allows deallocating statements to work in a failed transaction. This fixes a case where the prepared statement map could become invalid.
* Fix: Prefer driver.Valuer over json.Marshaler for json fields. (Jacopo)
* Fix: simple protocol SQL sanitizer previously panicked if an invalid $0 placeholder was used. This now returns an error instead. (maksymnevajdev)
* Add pgtype.Numeric.ScanScientific (Eshton Robateau)
# 5.5.0 (November 4, 2023)
* Add CollectExactlyOneRow. (Julien GOTTELAND)
* Add OpenDBFromPool to create *database/sql.DB from *pgxpool.Pool. (Lev Zakharov)
* Prepare can automatically choose statement name based on sql. This makes it easier to explicitly manage prepared statements.
* Statement cache now uses deterministic, stable statement names.
* database/sql prepared statement names are deterministically generated.
* Fix: SendBatch wasn't respecting context cancellation.
* Fix: Timeout error from pipeline is now normalized.
* Fix: database/sql encoding json.RawMessage to []byte.
* CancelRequest: Wait for the cancel request to be acknowledged by the server. This should improve PgBouncer compatibility. (Anton Levakin)
* stdlib: Use Ping instead of CheckConn in ResetSession
* Add json.Marshaler and json.Unmarshaler for Float4, Float8 (Kirill Mironov)
# 5.4.3 (August 5, 2023)
* Fix: QCharArrayOID was defined with the wrong OID (Christoph Engelbert)
* Fix: connect_timeout for sslmode=allow|prefer (smaher-edb)
* Fix: pgxpool: background health check cannot overflow pool
* Fix: Check for nil in defer when sending batch (recover properly from panic)
* Fix: json scan of non-string pointer to pointer
* Fix: zeronull.Timestamptz should use pgtype.Timestamptz
* Fix: NewConnsCount was not correctly counting connections created by Acquire directly. (James Hartig)
* RowTo(AddrOf)StructByPos ignores fields with "-" db tag
* Optimization: improve text format numeric parsing (horpto)
# 5.4.2 (July 11, 2023)
* Fix: RowScanner errors are fatal to Rows
* Fix: Enable failover efforts when pg_hba.conf disallows non-ssl connections (Brandon Kauffman)
* Hstore text codec internal improvements (Evan Jones)
* Fix: Stop timers for background reader when not in use. Fixes memory leak when closing connections (Adrian-Stefan Mares)
* Fix: Stop background reader as soon as possible.
* Add PgConn.SyncConn(). This combined with the above fix makes it safe to directly use the underlying net.Conn.
# 5.4.1 (June 18, 2023)
* Fix: concurrency bug with pgtypeDefaultMap and simple protocol (Lev Zakharov)
* Add TxOptions.BeginQuery to allow overriding the default BEGIN query
# 5.4.0 (June 14, 2023)
* Replace platform specific syscalls for non-blocking IO with more traditional goroutines and deadlines. This returns to the v4 approach with some additional improvements and fixes. This restores the ability to use a pgx.Conn over an ssh.Conn as well as other non-TCP or Unix socket connections. In addition, it is a significantly simpler implementation that is less likely to have cross platform issues.
* Optimization: The default type registrations are now shared among all connections. This saves about 100KB of memory per connection. `pgtype.Type` and `pgtype.Codec` values are now required to be immutable after registration. This was already necessary in most cases but wasn't documented until now. (Lev Zakharov)
* Fix: Ensure pgxpool.Pool.QueryRow.Scan releases connection on panic
* CancelRequest: don't try to read the reply (Nicola Murino)
* Fix: correctly handle bool type aliases (Wichert Akkerman)
* Fix: pgconn.CancelRequest: Fix unix sockets: don't use RemoteAddr()
* Fix: pgx.Conn memory leak with prepared statement caching (Evan Jones)
* Add BeforeClose to pgxpool.Pool (Evan Cordell)
* Fix: various hstore fixes and optimizations (Evan Jones)
* Fix: RowToStructByPos with embedded unexported struct
* Support different bool string representations (Lev Zakharov)
* Fix: error when using BatchResults.Exec on a select that returns an error after some rows.
* Fix: pipelineBatchResults.Exec() not returning error from ResultReader
* Fix: pipeline batch results not closing pipeline when error occurs while reading directly from results instead of using
a callback.
* Fix: scanning a table type into a struct
* Fix: scan array of record to pointer to slice of struct
* Fix: handle null for json (Cemre Mengu)
* Batch Query callback is called even when there is an error
* Add RowTo(AddrOf)StructByNameLax (Audi P. Risa P)
# 5.3.1 (February 27, 2023)
* Fix: Support v4 and v5 stdlib in same program (Tomáš Procházka)
* Fix: sql.Scanner not being used in certain cases
* Add text format jsonpath support
* Fix: fake non-blocking read adaptive wait time
# 5.3.0 (February 11, 2023)
* Fix: json values work with sql.Scanner
* Fixed / improved error messages (Mark Chambers and Yevgeny Pats)
* Fix: support scan into single dimensional arrays
* Fix: MaxConnLifetimeJitter setting actually jitters (Ben Weintraub)
* Fix: driver.Value representation of bytea should be []byte not string
* Fix: better handling of unregistered OIDs
* CopyFrom can use query cache to avoid extra round trip to get OIDs (Alejandro Do Nascimento Mora)
* Fix: encode to json ignoring driver.Valuer
* Support sql.Scanner on renamed base type
* Fix: pgtype.Numeric text encoding of negative numbers (Mark Chambers)
* Fix: connect with multiple hostnames when one can't be resolved
* Upgrade puddle to remove dependency on uber/atomic and fix alignment issue on 32-bit platform
* Fix: scanning json column into **string
* Multiple reductions in memory allocations
* Fake non-blocking read adapts its max wait time
* Improve CopyFrom performance and reduce memory usage
* Fix: encode []any to array
* Fix: LoadType for composite with dropped attributes (Felix Röhrich)
* Support v4 and v5 stdlib in same program
* Fix: text format array decoding with string of "NULL"
* Prefer binary format for arrays
# 5.2.0 (December 5, 2022)
* `tracelog.TraceLog` implements the pgx.PrepareTracer interface. (Vitalii Solodilov)
* Optimize creating begin transaction SQL string (Petr Evdokimov and ksco)
* `Conn.LoadType` supports range and multirange types (Vitalii Solodilov)
* Fix scan `uint` and `uint64` `ScanNumeric`. This resolves a PostgreSQL `numeric` being incorrectly scanned into `uint` and `uint64`.
# 5.1.1 (November 17, 2022)
* Fix simple query sanitizer where query text contains a Unicode replacement character.
* Remove erroneous `name` argument from `DeallocateAll()`. Technically, this is a breaking change, but given that method was only added 5 days ago this change was accepted. (Bodo Kaiser)
# 5.1.0 (November 12, 2022)
* Update puddle to v2.1.2. This resolves a race condition and a deadlock in pgxpool.
* `QueryRewriter.RewriteQuery` now returns an error. Technically, this is a breaking change for any external implementers, but given the minimal likelihood that there are actually any external implementers this change was accepted.
* Expose `GetSSLPassword` support to pgx.
* Fix encode `ErrorResponse` unknown field handling. This would only affect pgproto3 being used directly as a proxy with a non-PostgreSQL server that included additional error fields.
* Fix date text format encoding with 5 digit years.
* Fix date values passed to a `sql.Scanner` as `string` instead of `time.Time`.
* DateCodec.DecodeValue can return `pgtype.InfinityModifier` instead of `string` for infinite values. This now matches the behavior of the timestamp types.
* Add domain type support to `Conn.LoadType()`.
* Add `RowToStructByName` and `RowToAddrOfStructByName`. (Pavlo Golub)
* Add `Conn.DeallocateAll()` to clear all prepared statements including the statement cache. (Bodo Kaiser)
# 5.0.4 (October 24, 2022)
* Fix: CollectOneRow prefers PostgreSQL error over pgx.ErrNoRows
* Fix: some reflect Kind checks to first check for nil
* Bump golang.org/x/text dependency to placate snyk
* Fix: RowToStructByPos on structs with multiple anonymous sub-structs (Baptiste Fontaine)
* Fix: Exec checks if tx is closed
# 5.0.3 (October 14, 2022)
* Fix `driver.Valuer` handling edge cases that could cause infinite loop or crash
# v5.0.2 (October 8, 2022)
* Fix date encoding in text format to always use 2 digits for month and day
* Prefer driver.Valuer over wrap plans when encoding
* Fix scan to pointer to pointer to renamed type
* Allow scanning NULL even if PG and Go types are incompatible
# v5.0.1 (September 24, 2022)
* Fix 32-bit atomic usage
* Add MarshalJSON for Float8 (yogipristiawan)
* Add `[` and `]` to text encoding of `Lseg`
* Fix sqlScannerWrapper NULL handling
# v5.0.0 (September 17, 2022)
## Merged Packages ## Merged Packages
@@ -22,9 +304,11 @@ pgconn now supports pipeline mode.
`*PgConn.ReceiveResults` removed. Use pipeline mode instead. `*PgConn.ReceiveResults` removed. Use pipeline mode instead.
`Timeout()` no longer considers `context.Canceled` as a timeout error. `context.DeadlineExceeded` still is considered a timeout error.
## pgxpool ## pgxpool
`Connect` and `ConnectConfig` have been renamed to `New` and `NewConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect. `Connect` and `ConnectConfig` have been renamed to `New` and `NewWithConfig` respectively. The `LazyConnect` option has been removed. Pools always lazily connect.
## pgtype ## pgtype
@@ -33,7 +317,10 @@ The `pgtype` package has been significantly changed.
### NULL Representation ### NULL Representation
Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a Previously, types had a `Status` field that could be `Undefined`, `Null`, or `Present`. This has been changed to a
`Valid` `bool` field to harmonize with how `database/sql` represents NULL and to make the zero value useable. `Valid` `bool` field to harmonize with how `database/sql` represents `NULL` and to make the zero value useable.
Previously, a type that implemented `driver.Valuer` would have the `Value` method called even on a nil pointer. All nils
whether typed or untyped now represent `NULL`.
### Codec and Value Split ### Codec and Value Split
@@ -47,9 +334,9 @@ generally defined by implementing an interface that a particular `Codec` underst
### Array Types ### Array Types
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This significantly reduced the amount of code and the compiled binary size. This also means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional arrays.
All array types are now handled by `ArrayCodec` instead of using code generation for each new array type. This also means that less common array types such as `point[]` are now supported. `Array[T]` supports PostgreSQL multi-dimensional arrays.
### Composite Types ### Composite Types
@@ -63,7 +350,7 @@ easily be handled. Multirange types are handled similarly with `MultirangeCodec`
### pgxtype ### pgxtype
load data type moved to conn `LoadDataType` moved to `*Conn` as `LoadType`.
### Bytea ### Bytea
@@ -97,7 +384,7 @@ This matches the convention set by `database/sql`. In addition, for comparable t
### 3rd Party Type Integrations ### 3rd Party Type Integrations
* Extracted integrations with github.com/shopspring/decimal and github.com/gofrs/uuid to * Extracted integrations with https://github.com/shopspring/decimal and https://github.com/gofrs/uuid to
https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims https://github.com/jackc/pgx-shopspring-decimal and https://github.com/jackc/pgx-gofrs-uuid respectively. This trims
the pgx dependency tree. the pgx dependency tree.

CONTRIBUTING.md (new file, 121 lines)

@@ -0,0 +1,121 @@
# Contributing
## Discuss Significant Changes
Before you invest a significant amount of time on a change, please create a discussion or issue describing your
proposal. This will help to ensure your proposed change has a reasonable chance of being merged.
## Avoid Dependencies
Adding a dependency is a big deal. While on occasion a new dependency may be accepted, the default answer to any change
that adds a dependency is no.
## Development Environment Setup
pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE`
environment variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or key-value pairs. In addition,
the standard `PG*` environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to
simplify environment variable handling.
### Using an Existing PostgreSQL Cluster
If you already have a PostgreSQL development server this is the quickest way to start and run the majority of the pgx
test suite. Some tests will be skipped that require server configuration changes (e.g. those testing different
authentication methods).
Create and setup a test database:
```
export PGDATABASE=pgx_test
createdb
psql -c 'create extension hstore;'
psql -c 'create extension ltree;'
psql -c 'create domain uint64 as numeric(20,0);'
```
Ensure a `postgres` user exists. This happens by default in normal PostgreSQL installs, but some installation methods
such as Homebrew do not.
```
createuser -s postgres
```
Ensure your `PGX_TEST_DATABASE` environment variable points to the database you just created and run the tests.
```
export PGX_TEST_DATABASE="host=/private/tmp database=pgx_test"
go test ./...
```
This will run the vast majority of the tests, but some tests will be skipped (e.g. those testing different connection methods).
### Creating a New PostgreSQL Cluster Exclusively for Testing
The following environment variables need to be set both for initial setup and whenever the tests are run. (direnv is
highly recommended). Depending on your platform, you may need to change the host for `PGX_TEST_UNIX_SOCKET_CONN_STRING`.
```
export PGPORT=5015
export PGUSER=postgres
export PGDATABASE=pgx_test
export POSTGRESQL_DATA_DIR=postgresql
export PGX_TEST_DATABASE="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/private/tmp database=pgx_test"
export PGX_TEST_TCP_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_SCRAM_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_scram password=secret database=pgx_test"
export PGX_TEST_MD5_PASSWORD_CONN_STRING="host=127.0.0.1 database=pgx_test user=pgx_md5 password=secret"
export PGX_TEST_PLAIN_PASSWORD_CONN_STRING="host=127.0.0.1 user=pgx_pw password=secret"
export PGX_TEST_TLS_CONN_STRING="host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem"
export PGX_SSL_PASSWORD=certpw
export PGX_TEST_TLS_CLIENT_CONN_STRING="host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=`pwd`/.testdb/ca.pem database=pgx_test sslcert=`pwd`/.testdb/pgx_sslcert.crt sslkey=`pwd`/.testdb/pgx_sslcert.key"
```
Create a new database cluster.
```
initdb --locale=en_US -E UTF-8 --username=postgres .testdb/$POSTGRESQL_DATA_DIR
echo "listen_addresses = '127.0.0.1'" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
echo "port = $PGPORT" >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cat testsetup/postgresql_ssl.conf >> .testdb/$POSTGRESQL_DATA_DIR/postgresql.conf
cp testsetup/pg_hba.conf .testdb/$POSTGRESQL_DATA_DIR/pg_hba.conf
cd .testdb
# Generate CA, server, and encrypted client certificates.
go run ../testsetup/generate_certs.go
# Copy certificates to server directory and set permissions.
cp ca.pem $POSTGRESQL_DATA_DIR/root.crt
cp localhost.key $POSTGRESQL_DATA_DIR/server.key
chmod 600 $POSTGRESQL_DATA_DIR/server.key
cp localhost.crt $POSTGRESQL_DATA_DIR/server.crt
cd ..
```
Start the new cluster. This will be necessary whenever you are running pgx tests.
```
postgres -D .testdb/$POSTGRESQL_DATA_DIR
```
Set up the test database in the new cluster.
```
createdb
psql --no-psqlrc -f testsetup/postgresql_setup.sql
```
### PgBouncer
There are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
### Optional Tests
pgx supports multiple connection types and means of authentication. These tests are optional. They will only run if the
appropriate environment variables are set. In addition, there may be tests specific to particular PostgreSQL versions,
non-PostgreSQL servers (e.g. CockroachDB), or connection poolers (e.g. PgBouncer). `go test ./... -v | grep SKIP` to see
if any tests are being skipped.

README.md (132 changes)

@@ -1,15 +1,12 @@ [![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) [![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
[![](https://godoc.org/github.com/jackc/pgx?status.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5) [![Go Reference](https://pkg.go.dev/badge/github.com/jackc/pgx/v5.svg)](https://pkg.go.dev/github.com/jackc/pgx/v5)
[![Build Status](https://travis-ci.org/jackc/pgx.svg)](https://travis-ci.org/jackc/pgx) [![Build Status](https://github.com/jackc/pgx/actions/workflows/ci.yml/badge.svg)](https://github.com/jackc/pgx/actions/workflows/ci.yml)
# pgx - PostgreSQL Driver and Toolkit # pgx - PostgreSQL Driver and Toolkit
*This is the v5 development branch. It is still in active development and testing.*
pgx is a pure Go driver and toolkit for PostgreSQL. pgx is a pure Go driver and toolkit for PostgreSQL.
pgx aims to be low-level, fast, and performant, while also enabling PostgreSQL-specific features that the standard `database/sql` package does not allow for. The driver component of pgx can be used alongside the standard `database/sql` package.
The pgx driver is a low-level, high performance interface that exposes PostgreSQL-specific features such as `LISTEN` / `NOTIFY` and `COPY`. It also includes an adapter for the standard `database/sql` interface.
The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol The toolkit component is a related set of packages that implement PostgreSQL functionality such as parsing the wire protocol
and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers, and type mapping between PostgreSQL and Go. These underlying packages can be used to implement alternative drivers,
@@ -51,91 +48,55 @@ func main() {
See the [getting started guide](https://github.com/jackc/pgx/wiki/Getting-started-with-pgx) for more information.
- ## Choosing Between the pgx and database/sql Interfaces
- It is recommended to use the pgx interface if:
- 1. The application only targets PostgreSQL.
- 2. No other libraries that require `database/sql` are in use.
- The pgx interface is faster and exposes more features.
- The `database/sql` interface only allows the underlying driver to return or receive the following types: `int64`,
- `float64`, `bool`, `[]byte`, `string`, `time.Time`, or `nil`. Handling other types requires implementing the
- `database/sql.Scanner` and the `database/sql/driver/driver.Valuer` interfaces which require transmission of values in text format. The binary format can be substantially faster, which is what the pgx interface uses.
## Features
pgx supports many features beyond what is available through `database/sql`:
* Support for approximately 70 different PostgreSQL types
* Automatic statement preparation and caching
* Batch queries
* Single-round trip query mode
* Full TLS connection control
* Binary format support for custom types (allows for much quicker encoding/decoding)
- * COPY protocol support for faster bulk data loads
+ * `COPY` protocol support for faster bulk data loads
- * Extendable logging support
+ * Tracing and logging support
* Connection pool with after-connect hook for arbitrary connection setup
- * Listen / notify
+ * `LISTEN` / `NOTIFY`
* Conversion of PostgreSQL arrays to Go slice mappings for integers, floats, and strings
- * Hstore support
+ * `hstore` support
- * JSON and JSONB support
+ * `json` and `jsonb` support
* Maps `inet` and `cidr` PostgreSQL types to `netip.Addr` and `netip.Prefix`
* Large object support
- * NULL mapping to Null* struct or pointer to pointer
+ * NULL mapping to pointer to pointer
* Supports `database/sql.Scanner` and `database/sql/driver.Valuer` interfaces for custom types
* Notice response handling
* Simulated nested transactions with savepoints
- ## Performance
+ ## Choosing Between the pgx and database/sql Interfaces
- There are three areas in particular where pgx can provide a significant performance advantage over the standard
- `database/sql` interface and other drivers:
- 1. PostgreSQL specific types - Types such as arrays can be parsed much quicker because pgx uses the binary format.
- 2. Automatic statement preparation and caching - pgx will prepare and cache statements by default. This can provide a
-    significant free improvement to code that does not explicitly use prepared statements. Under certain workloads, it can
-    perform nearly 3x the number of queries per second.
- 3. Batched queries - Multiple queries can be batched together to minimize network round trips.
+ The pgx interface is faster. Many PostgreSQL specific features such as `LISTEN` / `NOTIFY` and `COPY` are not available
+ through the `database/sql` interface.
+ The pgx interface is recommended when:
+ 1. The application only targets PostgreSQL.
+ 2. No other libraries that require `database/sql` are in use.
+ It is also possible to use the `database/sql` interface and convert a connection to the lower-level pgx interface as needed.
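A minimal sketch of that conversion using the `pgx/v5/stdlib` adapter (the connection string is illustrative, and error handling is elided for brevity):
```
package main

import (
	"context"
	"database/sql"

	"github.com/jackc/pgx/v5/stdlib"
)

func main() {
	// stdlib registers the "pgx" driver name with database/sql.
	db, _ := sql.Open("pgx", "postgres://localhost:5432/pgx_test")
	defer db.Close()

	conn, _ := db.Conn(context.Background())
	defer conn.Close()

	// Drop down to the pgx interface for a single operation.
	_ = conn.Raw(func(driverConn any) error {
		pgxConn := driverConn.(*stdlib.Conn).Conn() // *pgx.Conn
		_, err := pgxConn.Exec(context.Background(), "select 1")
		return err
	})
}
```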
## Testing
- pgx tests naturally require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_DATABASE` environment
- variable. The `PGX_TEST_DATABASE` environment variable can either be a URL or DSN. In addition, the standard `PG*` environment
- variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify environment variable
- handling.
- ### Example Test Environment
- Connect to your PostgreSQL server and run:
-     create database pgx_test;
- Connect to the newly-created database and run:
-     create domain uint64 as numeric(20,0);
- Now, you can run the tests:
-     PGX_TEST_DATABASE="host=/var/run/postgresql database=pgx_test" go test ./...
- In addition, there are tests specific for PgBouncer that will be executed if `PGX_TEST_PGBOUNCER_CONN_STRING` is set.
+ See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.
+ ## Architecture
+ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.com/watch?v=sXMSWhcHCf8) for a description of pgx architecture.
## Supported Go and PostgreSQL Versions
- pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.17 and higher and PostgreSQL 10 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
+ pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.23 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
- `v5` is targeted at Go 1.18+. The general release of `v5` is not planned until second half of 2022 so it is expected that the policy of supporting the two most recent versions of Go will be maintained or restored soon after its release.
## Version Policy
- pgx follows semantic versioning for the documented public API on stable releases. `v4` is the latest stable major version.
+ pgx follows semantic versioning for the documented public API on stable releases. `v5` is the latest stable major version.
## PGX Family Libraries
@@ -159,8 +120,14 @@ pgerrcode contains constants for the PostgreSQL error codes.
* [github.com/jackc/pgx-gofrs-uuid](https://github.com/jackc/pgx-gofrs-uuid)
* [github.com/jackc/pgx-shopspring-decimal](https://github.com/jackc/pgx-shopspring-decimal)
+ * [github.com/twpayne/pgx-geos](https://github.com/twpayne/pgx-geos) ([PostGIS](https://postgis.net/) and [GEOS](https://libgeos.org/) via [go-geos](https://github.com/twpayne/go-geos))
* [github.com/vgarvardt/pgx-google-uuid](https://github.com/vgarvardt/pgx-google-uuid)
+ ## Adapters for 3rd Party Tracers
+ * [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
## Adapters for 3rd Party Loggers
These adapters can be used with the tracelog package.
@@ -170,13 +137,50 @@ These adapters can be used with the tracelog package.
* [github.com/jackc/pgx-logrus](https://github.com/jackc/pgx-logrus)
* [github.com/jackc/pgx-zap](https://github.com/jackc/pgx-zap)
* [github.com/jackc/pgx-zerolog](https://github.com/jackc/pgx-zerolog)
+ * [github.com/mcosta74/pgx-slog](https://github.com/mcosta74/pgx-slog)
+ * [github.com/kataras/pgx-golog](https://github.com/kataras/pgx-golog)
## 3rd Party Libraries with PGX Support
+ ### [github.com/pashagolub/pgxmock](https://github.com/pashagolub/pgxmock)
+ pgxmock is a mock library implementing pgx interfaces.
+ pgxmock has one and only purpose - to simulate pgx behavior in tests, without needing a real database connection.
### [github.com/georgysavva/scany](https://github.com/georgysavva/scany)
Library for scanning data from a database into Go structs and more.
- ### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
+ ### [github.com/vingarcia/ksql](https://github.com/vingarcia/ksql)
+ A carefully designed SQL client for making using SQL easier, more productive, and less error-prone on Golang.
+ ### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
Adds GSSAPI / Kerberos authentication support.
+ ### [github.com/wcamarao/pmx](https://github.com/wcamarao/pmx)
+ Explicit data mapping and scanning library for Go structs and slices.
+ ### [github.com/stephenafamo/scan](https://github.com/stephenafamo/scan)
+ Type safe and flexible package for scanning database data into Go types. Supports structs, maps, slices and custom mapping functions.
+ ### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
+ Code first migration library for native pgx (no database/sql abstraction).
+ ### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
+ A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
+ ### [github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
+ Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver.
+ ### [github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
+ Simplifies working with the pgx library, providing convenient scanning of nested structures.

Rakefile

@@ -2,7 +2,7 @@ require "erb"
rule '.go' => '.go.erb' do |task|
  erb = ERB.new(File.read(task.source))
-  File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
+  File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
  sh "goimports", "-w", task.name
end

batch.go (110 lines changed)

@@ -10,9 +10,9 @@ import (
// QueuedQuery is a query that has been queued for execution via a Batch.
type QueuedQuery struct {
-	query     string
-	arguments []any
-	fn        batchItemFunc
+	SQL       string
+	Arguments []any
+	Fn        batchItemFunc
	sd *pgconn.StatementDescription
}
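Exporting these fields means callers can now inspect a batch before it is sent. A small sketch (the logging helper is hypothetical, not part of pgx):
```
package main

import (
	"fmt"

	"github.com/jackc/pgx/v5"
)

// logBatch is a hypothetical helper that prints every queued query.
func logBatch(b *pgx.Batch) {
	for i, qq := range b.QueuedQueries {
		fmt.Printf("batch[%d]: %s %v\n", i, qq.SQL, qq.Arguments)
	}
}

func main() {
	b := &pgx.Batch{}
	b.Queue("select $1::int", 42)
	logBatch(b)
}
```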
@@ -20,14 +20,11 @@ type batchItemFunc func(br BatchResults) error
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
-	qq.fn = func(br BatchResults) error {
-		rows, err := br.Query()
-		if err != nil {
-			return err
-		}
+	qq.Fn = func(br BatchResults) error {
+		rows, _ := br.Query()
		defer rows.Close()

-		err = fn(rows)
+		err := fn(rows)
		if err != nil {
			return err
		}
@@ -39,7 +36,7 @@ func (qq *QueuedQuery) Query(fn func(rows Rows) error) {
// Query sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
-	qq.fn = func(br BatchResults) error {
+	qq.Fn = func(br BatchResults) error {
		row := br.QueryRow()
		return fn(row)
	}
@@ -47,7 +44,7 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
// Exec sets fn to be called when the response to qq is received.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
-	qq.fn = func(br BatchResults) error {
+	qq.Fn = func(br BatchResults) error {
		ct, err := br.Exec()
		if err != nil {
			return err
@@ -60,22 +57,28 @@ func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
// Batch queries are a way of bundling multiple queries together to avoid
// unnecessary network round trips. A Batch must only be sent once.
type Batch struct {
-	queuedQueries []*QueuedQuery
+	QueuedQueries []*QueuedQuery
}

-// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement.
+// Queue queues a query to batch b. query can be an SQL query or the name of a prepared statement. The only pgx option
+// argument that is supported is QueryRewriter. Queries are executed using the connection's DefaultQueryExecMode.
+//
+// While query can contain multiple statements if the connection's DefaultQueryExecMode is QueryModeSimple, this should
+// be avoided. QueuedQuery.Fn must not be set as it will only be called for the first query. That is, QueuedQuery.Query,
+// QueuedQuery.QueryRow, and QueuedQuery.Exec must not be called. In addition, any error messages or tracing that
+// include the current query may reference the wrong query.
func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery {
	qq := &QueuedQuery{
-		query:     query,
-		arguments: arguments,
+		SQL:       query,
+		Arguments: arguments,
	}
-	b.queuedQueries = append(b.queuedQueries, qq)
+	b.QueuedQueries = append(b.QueuedQueries, qq)
	return qq
}

// Len returns number of queries that have been queued so far.
func (b *Batch) Len() int {
-	return len(b.queuedQueries)
+	return len(b.QueuedQueries)
}
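For reference, typical use of the API above is to queue queries, optionally attach per-query callbacks, and send the batch once. A sketch (the widgets table is hypothetical; connection setup omitted):
```
package main

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
)

// runBatch sketches typical Batch usage; conn is an established *pgx.Conn.
func runBatch(ctx context.Context, conn *pgx.Conn) error {
	batch := &pgx.Batch{}
	batch.Queue("insert into widgets(name) values($1)", "gear").Exec(func(ct pgconn.CommandTag) error {
		fmt.Println("inserted", ct.RowsAffected(), "row(s)")
		return nil
	})
	batch.Queue("select count(*) from widgets").QueryRow(func(row pgx.Row) error {
		var n int64
		return row.Scan(&n)
	})
	// Close runs the queued callbacks in order and returns the first error encountered.
	return conn.SendBatch(ctx, batch).Close()
}
```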
type BatchResults interface {
@@ -129,7 +132,7 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
	if !br.mrr.NextResult() {
		err := br.mrr.Close()
		if err == nil {
-			err = errors.New("no result")
+			err = errors.New("no more results in batch")
		}
		if br.conn.batchTracer != nil {
			br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
@@ -142,7 +145,10 @@ func (br *batchResults) Exec() (pgconn.CommandTag, error) {
	}

	commandTag, err := br.mrr.ResultReader().Close()
+	if err != nil {
		br.err = err
+		br.mrr.Close()
+	}

	if br.conn.batchTracer != nil {
		br.conn.batchTracer.TraceBatchQuery(br.ctx, br.conn, TraceBatchQueryData{
@@ -178,7 +184,7 @@ func (br *batchResults) Query() (Rows, error) {
	if !br.mrr.NextResult() {
		rows.err = br.mrr.Close()
		if rows.err == nil {
-			rows.err = errors.New("no result")
+			rows.err = errors.New("no more results in batch")
		}
		rows.closed = true
@@ -225,10 +231,10 @@ func (br *batchResults) Close() error {
	}

	// Read and run fn for all remaining items
-	for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
-		if br.b.queuedQueries[br.qqIdx].fn != nil {
-			err := br.b.queuedQueries[br.qqIdx].fn(br)
-			if err != nil && br.err == nil {
+	for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
+		if br.b.QueuedQueries[br.qqIdx].Fn != nil {
+			err := br.b.QueuedQueries[br.qqIdx].Fn(br)
+			if err != nil {
				br.err = err
			}
		} else {
@@ -251,10 +257,10 @@ func (br *batchResults) earlyError() error {
}

func (br *batchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
-	if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
-		bi := br.b.queuedQueries[br.qqIdx]
-		query = bi.query
-		args = bi.arguments
+	if br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
+		bi := br.b.QueuedQueries[br.qqIdx]
+		query = bi.SQL
+		args = bi.Arguments
		ok = true
		br.qqIdx++
	}
@@ -285,12 +291,15 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
		return pgconn.CommandTag{}, br.err
	}

-	query, arguments, _ := br.nextQueryAndArgs()
+	query, arguments, err := br.nextQueryAndArgs()
+	if err != nil {
+		return pgconn.CommandTag{}, err
+	}

	results, err := br.pipeline.GetResults()
	if err != nil {
		br.err = err
-		return pgconn.CommandTag{}, err
+		return pgconn.CommandTag{}, br.err
	}

	var commandTag pgconn.CommandTag
	switch results := results.(type) {
@@ -309,7 +318,7 @@ func (br *pipelineBatchResults) Exec() (pgconn.CommandTag, error) {
		})
	}

-	return commandTag, err
+	return commandTag, br.err
}

// Query reads the results from the next query in the batch as if the query has been sent with Query.
@@ -328,9 +337,9 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
		return &baseRows{err: br.err, closed: true}, br.err
	}

-	query, arguments, ok := br.nextQueryAndArgs()
-	if !ok {
-		query = "batch query"
+	query, arguments, err := br.nextQueryAndArgs()
+	if err != nil {
+		return &baseRows{err: err, closed: true}, err
	}

	rows := br.conn.getRows(br.ctx, query, arguments)
@@ -384,24 +393,20 @@ func (br *pipelineBatchResults) Close() error {
		}
	}()

-	if br.err != nil {
-		return br.err
-	}
-
-	if br.lastRows != nil && br.lastRows.err != nil {
+	if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
		br.err = br.lastRows.err
		return br.err
	}

	if br.closed {
-		return nil
+		return br.err
	}

	// Read and run fn for all remaining items
-	for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
-		if br.b.queuedQueries[br.qqIdx].fn != nil {
-			err := br.b.queuedQueries[br.qqIdx].fn(br)
-			if err != nil && br.err == nil {
+	for br.err == nil && !br.closed && br.b != nil && br.qqIdx < len(br.b.QueuedQueries) {
+		if br.b.QueuedQueries[br.qqIdx].Fn != nil {
+			err := br.b.QueuedQueries[br.qqIdx].Fn(br)
+			if err != nil {
				br.err = err
			}
		} else {
@@ -423,13 +428,16 @@ func (br *pipelineBatchResults) earlyError() error {
	return br.err
}

-func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, ok bool) {
-	if br.b != nil && br.qqIdx < len(br.b.queuedQueries) {
-		bi := br.b.queuedQueries[br.qqIdx]
-		query = bi.query
-		args = bi.arguments
-		ok = true
-		br.qqIdx++
-	}
-	return
+func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, err error) {
+	if br.b == nil {
+		return "", nil, errors.New("no reference to batch")
+	}
+
+	if br.qqIdx >= len(br.b.QueuedQueries) {
+		return "", nil, errors.New("no more results in batch")
+	}
+
+	bi := br.b.QueuedQueries[br.qqIdx]
+	br.qqIdx++
+	return bi.SQL, bi.Arguments, nil
}

batch_test.go

@@ -6,6 +6,7 @@ import (
	"fmt"
	"os"
	"testing"
+	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgconn"
@@ -17,7 +18,10 @@ import (
func TestConnSendBatch(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
		sql := `create temporary table ledger(
@@ -36,7 +40,7 @@ func TestConnSendBatch(t *testing.T) {
		batch.Queue("select * from ledger where false")
		batch.Queue("select sum(amount) from ledger")
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		ct, err := br.Exec()
		if err != nil {
@@ -152,7 +156,10 @@ func TestConnSendBatch(t *testing.T) {
func TestConnSendBatchQueuedQuery(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
		sql := `create temporary table ledger(
@@ -237,7 +244,7 @@ func TestConnSendBatchQueuedQuery(t *testing.T) {
			return nil
		})
-		err := conn.SendBatch(context.Background(), batch).Close()
+		err := conn.SendBatch(ctx, batch).Close()
		assert.NoError(t, err)
	})
}
@@ -245,7 +252,10 @@ func TestConnSendBatchQueuedQuery(t *testing.T) {
func TestConnSendBatchMany(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger(
			id serial primary key,
			description varchar not null,
@@ -262,7 +272,7 @@ func TestConnSendBatchMany(t *testing.T) {
		}
		batch.Queue("select count(*) from ledger")
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		for i := 0; i < numInserts; i++ {
			ct, err := br.Exec()
@@ -280,6 +290,45 @@ func TestConnSendBatchMany(t *testing.T) {
	})
}
// https://github.com/jackc/pgx/issues/1801#issuecomment-2203784178
func TestConnSendBatchReadResultsWhenNothingQueued(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
br := conn.SendBatch(ctx, batch)
commandTag, err := br.Exec()
require.Equal(t, "", commandTag.String())
require.EqualError(t, err, "no more results in batch")
err = br.Close()
require.NoError(t, err)
})
}
func TestConnSendBatchReadMoreResultsThanQueriesSent(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 1")
br := conn.SendBatch(ctx, batch)
commandTag, err := br.Exec()
require.Equal(t, "SELECT 1", commandTag.String())
require.NoError(t, err)
commandTag, err = br.Exec()
require.Equal(t, "", commandTag.String())
require.EqualError(t, err, "no more results in batch")
err = br.Close()
require.NoError(t, err)
})
}
func TestConnSendBatchWithPreparedStatement(t *testing.T) {
	t.Parallel()
@@ -290,9 +339,12 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
		pgx.QueryExecModeExec,
		// Don't test simple mode with prepared statements.
	}
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
-		_, err := conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
+		_, err := conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
		if err != nil {
			t.Fatal(err)
		}
@@ -304,7 +356,7 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
			batch.Queue("ps1", 5)
		}
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		for i := 0; i < queryCount; i++ {
			rows, err := br.Query()
@@ -337,13 +389,16 @@ func TestConnSendBatchWithPreparedStatement(t *testing.T) {
func TestConnSendBatchWithQueryRewriter(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("something to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{1}})
		batch.Queue("something else to be replaced", &testQueryRewriter{sql: "select $1::text", args: []any{"hello"}})
		batch.Queue("more to be replaced", &testQueryRewriter{sql: "select $1::int", args: []any{3}})
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		var n int32
		err := br.QueryRow().Scan(&n)
@@ -368,6 +423,9 @@ func TestConnSendBatchWithQueryRewriter(t *testing.T) {
func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.T) {
	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
@@ -380,7 +438,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
	pgxtest.SkipCockroachDB(t, conn, "Server issues incorrect ParameterDescription (https://github.com/cockroachdb/cockroach/issues/60907)")
-	_, err = conn.Prepare(context.Background(), "ps1", "select n from generate_series(0,$1::int) n")
+	_, err = conn.Prepare(ctx, "ps1", "select n from generate_series(0,$1::int) n")
	if err != nil {
		t.Fatal(err)
	}
@@ -392,7 +450,7 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
		batch.Queue("ps1", 5)
	}
-	br := conn.SendBatch(context.Background(), batch)
+	br := conn.SendBatch(ctx, batch)
	for i := 0; i < queryCount; i++ {
		rows, err := br.Query()
@@ -426,13 +484,16 @@ func TestConnSendBatchWithPreparedStatementAndStatementCacheDisabled(t *testing.
func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("select n from generate_series(0,5) n")
		batch.Queue("select n from generate_series(0,5) n")
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		rows, err := br.Query()
		if err != nil {
@@ -485,13 +546,16 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
func TestConnSendBatchQueryError(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
		batch.Queue("select n from generate_series(0,5) n")
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		rows, err := br.Query()
		if err != nil {
@@ -523,12 +587,15 @@ func TestConnSendBatchQueryError(t *testing.T) {
func TestConnSendBatchQuerySyntaxError(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		batch := &pgx.Batch{}
		batch.Queue("select 1 1")
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		var n int32
		err := br.QueryRow().Scan(&n)
@@ -547,7 +614,10 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
func TestConnSendBatchQueryRowInsert(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger(
			id serial primary key,
@@ -560,7 +630,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
		batch.Queue("select 1")
		batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		var value int
		err := br.QueryRow().Scan(&value)
@@ -584,7 +654,10 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger(
			id serial primary key,
@@ -597,7 +670,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
		batch.Queue("select 1 union all select 2 union all select 3")
		batch.Queue("insert into ledger(description, amount) values($1, $2),($1, $2)", "q1", 1)
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		rows, err := br.Query()
		if err != nil {
@@ -621,7 +694,10 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
func TestTxSendBatch(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger1(
			id serial primary key,
@@ -635,7 +711,7 @@ func TestTxSendBatch(t *testing.T) {
		);`
		mustExec(t, conn, sql)
-		tx, _ := conn.Begin(context.Background())
+		tx, _ := conn.Begin(ctx)
		batch := &pgx.Batch{}
		batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
@@ -652,7 +728,7 @@ func TestTxSendBatch(t *testing.T) {
		batch.Queue("insert into ledger2(id,amount) values($1, $2)", id, 2)
		batch.Queue("select amount from ledger2 where id = $1", id)
-		br = tx.SendBatch(context.Background(), batch)
+		br = tx.SendBatch(ctx, batch)
		ct, err := br.Exec()
		if err != nil {
@@ -669,10 +745,10 @@ func TestTxSendBatch(t *testing.T) {
		}
		br.Close()
-		tx.Commit(context.Background())
+		tx.Commit(ctx)
		var count int
-		conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id).Scan(&count)
+		conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id).Scan(&count)
		if count != 1 {
			t.Errorf("count => %v, want %v", count, 1)
		}
@@ -688,7 +764,10 @@ func TestTxSendBatch(t *testing.T) {
func TestTxSendBatchRollback(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		sql := `create temporary table ledger1(
			id serial primary key,
@@ -696,11 +775,11 @@ func TestTxSendBatchRollback(t *testing.T) {
		);`
		mustExec(t, conn, sql)
-		tx, _ := conn.Begin(context.Background())
+		tx, _ := conn.Begin(ctx)
		batch := &pgx.Batch{}
		batch.Queue("insert into ledger1(description) values($1) returning id", "q1")
-		br := tx.SendBatch(context.Background(), batch)
+		br := tx.SendBatch(ctx, batch)
		var id int
		err := br.QueryRow().Scan(&id)
@@ -708,9 +787,9 @@ func TestTxSendBatchRollback(t *testing.T) {
			t.Error(err)
		}
		br.Close()
-		tx.Rollback(context.Background())
+		tx.Rollback(ctx)
-		row := conn.QueryRow(context.Background(), "select count(1) from ledger1 where id = $1", id)
+		row := conn.QueryRow(ctx, "select count(1) from ledger1 where id = $1", id)
		var count int
		row.Scan(&count)
		if count != 0 {
@@ -720,10 +799,62 @@ func TestTxSendBatchRollback(t *testing.T) {
	})
}
// https://github.com/jackc/pgx/issues/1578
func TestSendBatchErrorWhileReadingResultsWithoutCallback(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 4 / $1::int", 0)
batchResult := conn.SendBatch(ctx, batch)
_, execErr := batchResult.Exec()
require.Error(t, execErr)
closeErr := batchResult.Close()
require.Equal(t, execErr, closeErr)
// Try to use the connection.
_, err := conn.Exec(ctx, "select 1")
require.NoError(t, err)
})
}
func TestSendBatchErrorWhileReadingResultsWithExecWhereSomeRowsAreReturned(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{}
batch.Queue("select 4 / n from generate_series(-2, 2) n")
batchResult := conn.SendBatch(ctx, batch)
_, execErr := batchResult.Exec()
require.Error(t, execErr)
closeErr := batchResult.Close()
require.Equal(t, execErr, closeErr)
// Try to use the connection.
_, err := conn.Exec(ctx, "select 1")
require.NoError(t, err)
})
}
func TestConnBeginBatchDeferredError(t *testing.T) {
	t.Parallel()
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
		pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
@@ -739,7 +870,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
		batch.Queue(`update t set n=n+1 where id='b' returning *`)
-		br := conn.SendBatch(context.Background(), batch)
+		br := conn.SendBatch(ctx, batch)
		rows, err := br.Query()
		if err != nil {
@@ -768,6 +899,9 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
}

func TestConnSendBatchNoStatementCache(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	config.DefaultQueryExecMode = pgx.QueryExecModeDescribeExec
	config.StatementCacheCapacity = 0
@@ -776,10 +910,13 @@ func TestConnSendBatchNoStatementCache(t *testing.T) {
	conn := mustConnect(t, config)
	defer closeConn(t, conn)
-	testConnSendBatch(t, conn, 3)
+	testConnSendBatch(t, ctx, conn, 3)
}

func TestConnSendBatchPrepareStatementCache(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
	config.StatementCacheCapacity = 32
@@ -787,10 +924,13 @@ func TestConnSendBatchPrepareStatementCache(t *testing.T) {
	conn := mustConnect(t, config)
	defer closeConn(t, conn)
-	testConnSendBatch(t, conn, 3)
+	testConnSendBatch(t, ctx, conn, 3)
}

func TestConnSendBatchDescribeStatementCache(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
	config.DescriptionCacheCapacity = 32
@@ -798,16 +938,16 @@ func TestConnSendBatchDescribeStatementCache(t *testing.T) {
	conn := mustConnect(t, config)
	defer closeConn(t, conn)
-	testConnSendBatch(t, conn, 3)
+	testConnSendBatch(t, ctx, conn, 3)
}

-func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
+func testConnSendBatch(t *testing.T, ctx context.Context, conn *pgx.Conn, queryCount int) {
	batch := &pgx.Batch{}
	for j := 0; j < queryCount; j++ {
		batch.Queue("select n from generate_series(0,5) n")
	}
-	br := conn.SendBatch(context.Background(), batch)
+	br := conn.SendBatch(ctx, batch)
	for j := 0; j < queryCount; j++ {
		rows, err := br.Query()
@@ -830,12 +970,12 @@ func testConnSendBatch(t *testing.T, conn *pgx.Conn, queryCount int) {
func TestSendBatchSimpleProtocol(t *testing.T) {
	t.Parallel()
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
	config := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
	config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
-	ctx, cancelFunc := context.WithCancel(context.Background())
-	defer cancelFunc()
-
	conn := mustConnect(t, config)
	defer closeConn(t, conn)
@@ -868,8 +1008,41 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
	assert.False(t, rows.Next())
}
// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
mustExec(t, conn, `create temporary table foo(col1 text primary key);`)
batch := &pgx.Batch{}
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err := conn.SendBatch(ctx, batch).Close()
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
mustExec(t, conn, `create temporary table baz(col1 text primary key);`)
// Since table baz now exists, the batch should succeed.
batch = &pgx.Batch{}
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err = conn.SendBatch(ctx, batch).Close()
require.NoError(t, err)
})
}
func ExampleConn_SendBatch() {
-	conn, err := pgx.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		fmt.Printf("Unable to establish connection: %v", err)
		return
@@ -912,7 +1085,7 @@ func ExampleConn_SendBatch() {
		return err
	})
-	err = conn.SendBatch(context.Background(), batch).Close()
+	err = conn.SendBatch(ctx, batch).Close()
	if err != nil {
		fmt.Printf("SendBatch error: %v", err)
		return

benchmark_test.go

@@ -13,7 +13,6 @@ import (
	"time"

	"github.com/jackc/pgx/v5"
-	"github.com/jackc/pgx/v5/internal/nbconn"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
	"github.com/stretchr/testify/require"
@@ -152,7 +151,7 @@ func BenchmarkMinimalPgConnPreparedSelect(b *testing.B) {
	for rr.NextRow() {
		for i := range rr.Values() {
-			if bytes.Compare(rr.Values()[0], encodedBytes) != 0 {
+			if !bytes.Equal(rr.Values()[0], encodedBytes) {
				b.Fatalf("unexpected values: %s %s", rr.Values()[i], encodedBytes)
			}
		}
@@ -340,8 +339,9 @@ type benchmarkWriteTableCopyFromSrc struct {
}

func (s *benchmarkWriteTableCopyFromSrc) Next() bool {
+	next := s.idx < s.count
	s.idx++
-	return s.idx < s.count
+	return next
}

func (s *benchmarkWriteTableCopyFromSrc) Values() ([]any, error) {
@@ -407,6 +407,34 @@ func benchmarkWriteNRowsViaInsert(b *testing.B, n int) {
	}
}
func benchmarkWriteNRowsViaBatchInsert(b *testing.B, n int) {
conn := mustConnect(b, mustParseConfig(b, os.Getenv("PGX_TEST_DATABASE")))
defer closeConn(b, conn)
mustExec(b, conn, benchmarkWriteTableCreateSQL)
_, err := conn.Prepare(context.Background(), "insert_t", benchmarkWriteTableInsertSQL)
if err != nil {
b.Fatal(err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
src := newBenchmarkWriteTableCopyFromSrc(n)
batch := &pgx.Batch{}
for src.Next() {
values, _ := src.Values()
batch.Queue("insert_t", values...)
}
err = conn.SendBatch(context.Background(), batch).Close()
if err != nil {
b.Fatal(err)
}
}
}
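For comparison, the ViaCopy benchmarks in this file drive the same row source through the COPY protocol. A sketch of that path (the table and column names here are placeholders; the real ones live in the benchmark setup SQL not shown in this diff):
```
// copySketch is illustrative only; it mirrors what benchmarkWriteNRowsViaCopy does.
func copySketch(ctx context.Context, conn *pgx.Conn, n int) error {
	src := newBenchmarkWriteTableCopyFromSrc(n) // implements pgx.CopyFromSource
	_, err := conn.CopyFrom(ctx, pgx.Identifier{"t"}, []string{"varchar_1", "varchar_2"}, src)
	return err
}
```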
type queryArgs []any

func (qa *queryArgs) Append(v any) string {
@@ -484,7 +512,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
	}

	if err := tx.Commit(context.Background()); err != nil {
-		return 0, nil
+		return 0, err
	}

	return rowCount, nil
@@ -560,6 +588,22 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
	}
}
func BenchmarkWrite2RowsViaInsert(b *testing.B) {
benchmarkWriteNRowsViaInsert(b, 2)
}
func BenchmarkWrite2RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 2)
}
func BenchmarkWrite2RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 2)
}
func BenchmarkWrite2RowsViaCopy(b *testing.B) {
benchmarkWriteNRowsViaCopy(b, 2)
}
func BenchmarkWrite5RowsViaInsert(b *testing.B) {
	benchmarkWriteNRowsViaInsert(b, 5)
}
@@ -567,6 +611,9 @@ func BenchmarkWrite5RowsViaInsert(b *testing.B) {
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 5)
}
func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 5)
}
func BenchmarkWrite5RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 5)
}
@@ -579,6 +626,9 @@ func BenchmarkWrite10RowsViaInsert(b *testing.B) {
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 10)
}
func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 10)
}
func BenchmarkWrite10RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 10)
}
@@ -591,6 +641,9 @@ func BenchmarkWrite100RowsViaInsert(b *testing.B) {
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 100)
}
func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 100)
}
func BenchmarkWrite100RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 100)
}
@@ -604,6 +657,10 @@ func BenchmarkWrite1000RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 1000)
}
func BenchmarkWrite1000RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 1000)
}
func BenchmarkWrite1000RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 1000)
}
@@ -615,6 +672,9 @@ func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
	benchmarkWriteNRowsViaMultiInsert(b, 10000)
}
func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 10000)
}
func BenchmarkWrite10000RowsViaCopy(b *testing.B) {
	benchmarkWriteNRowsViaCopy(b, 10000)
}
@@ -884,6 +944,7 @@ type BenchRowSimple struct {
	BirthDate  time.Time
	Weight     int32
	Height     int32
+	Tags       []string
	UpdateTime time.Time
}
@ -897,13 +958,13 @@ func BenchmarkSelectRowsScanSimple(b *testing.B) {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
br := &BenchRowSimple{} br := &BenchRowSimple{}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
for rows.Next() { for rows.Next() {
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
} }
if rows.Err() != nil { if rows.Err() != nil {
@ -922,6 +983,7 @@ type BenchRowStringBytes struct {
BirthDate time.Time BirthDate time.Time
Weight int32 Weight int32
Height int32 Height int32
Tags []string
UpdateTime time.Time UpdateTime time.Time
} }
@ -935,13 +997,13 @@ func BenchmarkSelectRowsScanStringBytes(b *testing.B) {
b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) { b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
br := &BenchRowStringBytes{} br := &BenchRowStringBytes{}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount) rows, err := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
for rows.Next() { for rows.Next() {
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
} }
if rows.Err() != nil { if rows.Err() != nil {
@ -960,6 +1022,7 @@ type BenchRowDecoder struct {
BirthDate pgtype.Date BirthDate pgtype.Date
Weight pgtype.Int4 Weight pgtype.Int4
Height pgtype.Int4 Height pgtype.Int4
Tags pgtype.FlatArray[string]
UpdateTime pgtype.Timestamptz UpdateTime pgtype.Timestamptz
} }
@ -985,7 +1048,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rows, err := conn.Query( rows, err := conn.Query(
context.Background(), context.Background(),
"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
pgx.QueryResultFormats{format.code}, pgx.QueryResultFormats{format.code},
rowCount, rowCount,
) )
@ -994,7 +1057,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
} }
for rows.Next() { for rows.Next() {
rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.UpdateTime) rows.Scan(&br.ID, &br.FirstName, &br.LastName, &br.Sex, &br.BirthDate, &br.Weight, &br.Height, &br.Tags, &br.UpdateTime)
} }
if rows.Err() != nil { if rows.Err() != nil {
@@ -1016,7 +1079,7 @@ func BenchmarkSelectRowsPgConnExecText(b *testing.B) {
 	for _, rowCount := range rowCounts {
 		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
 			for i := 0; i < b.N; i++ {
-				mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
+				mrr := conn.PgConn().Exec(context.Background(), fmt.Sprintf("select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + %d) n", rowCount))
 				for mrr.NextResult() {
 					rr := mrr.ResultReader()
 					for rr.NextRow() {
@@ -1053,11 +1116,11 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
 			for i := 0; i < b.N; i++ {
 				rr := conn.PgConn().ExecParams(
 					context.Background(),
-					"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
+					"select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n",
 					[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
 					nil,
 					nil,
-					[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
+					[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
 				)
 				for rr.NextRow() {
 					rr.Values()
@@ -1074,13 +1137,107 @@ func BenchmarkSelectRowsPgConnExecParams(b *testing.B) {
 	}
 }
+
+func BenchmarkSelectRowsSimpleCollectRowsRowToStructByPos(b *testing.B) {
+	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(b, conn)
+
+	rowCounts := getSelectRowsCounts(b)
+
+	for _, rowCount := range rowCounts {
+		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
+				benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByPos[BenchRowSimple])
+				if err != nil {
+					b.Fatal(err)
+				}
+				if len(benchRows) != int(rowCount) {
+					b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
+				}
+			}
+		})
+	}
+}
+
+func BenchmarkSelectRowsSimpleAppendRowsRowToStructByPos(b *testing.B) {
+	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(b, conn)
+
+	rowCounts := getSelectRowsCounts(b)
+
+	for _, rowCount := range rowCounts {
+		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
+			benchRows := make([]BenchRowSimple, 0, rowCount)
+			for i := 0; i < b.N; i++ {
+				benchRows = benchRows[:0]
+				rows, _ := conn.Query(context.Background(), "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", rowCount)
+				var err error
+				benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByPos[BenchRowSimple])
+				if err != nil {
+					b.Fatal(err)
+				}
+				if len(benchRows) != int(rowCount) {
+					b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
+				}
+			}
+		})
+	}
+}
+
+func BenchmarkSelectRowsSimpleCollectRowsRowToStructByName(b *testing.B) {
+	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(b, conn)
+
+	rowCounts := getSelectRowsCounts(b)
+
+	for _, rowCount := range rowCounts {
+		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
+			for i := 0; i < b.N; i++ {
+				rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
+				benchRows, err := pgx.CollectRows(rows, pgx.RowToStructByName[BenchRowSimple])
+				if err != nil {
+					b.Fatal(err)
+				}
+				if len(benchRows) != int(rowCount) {
+					b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
+				}
+			}
+		})
+	}
+}
+
+func BenchmarkSelectRowsSimpleAppendRowsRowToStructByName(b *testing.B) {
+	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
+	defer closeConn(b, conn)
+
+	rowCounts := getSelectRowsCounts(b)
+
+	for _, rowCount := range rowCounts {
+		b.Run(fmt.Sprintf("%d rows", rowCount), func(b *testing.B) {
+			benchRows := make([]BenchRowSimple, 0, rowCount)
+			for i := 0; i < b.N; i++ {
+				benchRows = benchRows[:0]
+				rows, _ := conn.Query(context.Background(), "select n as id, 'Adam' as first_name, 'Smith ' || n as last_name, 'male' as sex, '1952-06-16'::date as birth_date, 258 as weight, 72 as height, '{foo,bar,baz}'::text[] as tags, '2001-01-28 01:02:03-05'::timestamptz as update_time from generate_series(100001, 100000 + $1) n", rowCount)
+				var err error
+				benchRows, err = pgx.AppendRows(benchRows, rows, pgx.RowToStructByName[BenchRowSimple])
+				if err != nil {
+					b.Fatal(err)
+				}
+				if len(benchRows) != int(rowCount) {
+					b.Fatalf("Expected %d rows, got %d", rowCount, len(benchRows))
+				}
+			}
+		})
+	}
+}
 func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
 	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(b, conn)
 
 	rowCounts := getSelectRowsCounts(b)
 
-	_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
+	_, err := conn.PgConn().Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
 	if err != nil {
 		b.Fatal(err)
 	}
@@ -1102,7 +1259,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
 					"ps1",
 					[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
 					nil,
-					[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
+					[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
 				)
 				for rr.NextRow() {
 					rr.Values()
@@ -1120,7 +1277,7 @@ func BenchmarkSelectRowsPgConnExecPrepared(b *testing.B) {
 }
 
 type queryRecorder struct {
-	conn      nbconn.Conn
+	conn      net.Conn
 	writeBuf  []byte
 	readCount int
 }
@@ -1136,14 +1293,6 @@ func (qr *queryRecorder) Write(b []byte) (n int, err error) {
 	return qr.conn.Write(b)
 }
 
-func (qr *queryRecorder) BufferReadUntilBlock() error {
-	return qr.conn.BufferReadUntilBlock()
-}
-
-func (qr *queryRecorder) Flush() error {
-	return qr.conn.Flush()
-}
-
 func (qr *queryRecorder) Close() error {
 	return qr.conn.Close()
 }
@@ -1189,7 +1338,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
 	conn := mustConnectString(b, os.Getenv("PGX_TEST_DATABASE")).PgConn()
 	defer conn.Close(context.Background())
 
-	_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
+	_, err := conn.Prepare(context.Background(), "ps1", "select n, 'Adam', 'Smith ' || n, 'male', '1952-06-16'::date, 258, 72, '{foo,bar,baz}'::text[], '2001-01-28 01:02:03-05'::timestamptz from generate_series(100001, 100000 + $1) n", nil)
 	if err != nil {
 		b.Fatal(err)
 	}
@@ -1212,7 +1361,7 @@ func BenchmarkSelectRowsRawPrepared(b *testing.B) {
 				"ps1",
 				[][]byte{[]byte(strconv.FormatInt(rowCount, 10))},
 				nil,
-				[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code},
+				[]int16{format.code, pgx.TextFormatCode, pgx.TextFormatCode, pgx.TextFormatCode, format.code, format.code, format.code, format.code, format.code},
 			)
 			_, err := rr.Close()
 			require.NoError(b, err)

@@ -9,40 +9,41 @@ then
   sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list"
   sudo apt-get update -qq
   sudo apt-get -y -o Dpkg::Options::=--force-confdef -o Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION
-  sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "host all pgx_md5 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "host all pgx_pw 127.0.0.1/32 password" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "hostssl all pgx_ssl 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "host replication pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  echo "host pgx_test pgx_replication 127.0.0.1/32 md5" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf
-  sudo chmod 777 /etc/postgresql/$PGVERSION/main/postgresql.conf
-  if $(dpkg --compare-versions $PGVERSION ge 9.6) ; then
-    echo "wal_level='logical'" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
-    echo "max_wal_senders=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
-    echo "max_replication_slots=5" >> /etc/postgresql/$PGVERSION/main/postgresql.conf
-  fi
+  sudo cp testsetup/pg_hba.conf /etc/postgresql/$PGVERSION/main/pg_hba.conf
+  sudo sh -c "echo \"listen_addresses = '127.0.0.1'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
+  sudo sh -c "cat testsetup/postgresql_ssl.conf >> /etc/postgresql/$PGVERSION/main/postgresql.conf"
+
+  cd testsetup
+
+  # Generate CA, server, and encrypted client certificates.
+  go run generate_certs.go
+
+  # Copy certificates to server directory and set permissions.
+  sudo cp ca.pem /var/lib/postgresql/$PGVERSION/main/root.crt
+  sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/root.crt
+  sudo cp localhost.key /var/lib/postgresql/$PGVERSION/main/server.key
+  sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.key
+  sudo chmod 600 /var/lib/postgresql/$PGVERSION/main/server.key
+  sudo cp localhost.crt /var/lib/postgresql/$PGVERSION/main/server.crt
+  sudo chown postgres:postgres /var/lib/postgresql/$PGVERSION/main/server.crt
+
+  cp ca.pem /tmp
+  cp pgx_sslcert.key /tmp
+  cp pgx_sslcert.crt /tmp
+
+  cd ..
+
   sudo /etc/init.d/postgresql restart
-  psql -U postgres -c 'create database pgx_test'
-  psql -U postgres pgx_test -c 'create extension hstore'
-  psql -U postgres pgx_test -c 'create domain uint64 as numeric(20,0)'
-  psql -U postgres -c "create user pgx_ssl SUPERUSER PASSWORD 'secret'"
-  psql -U postgres -c "create user pgx_md5 SUPERUSER PASSWORD 'secret'"
-  psql -U postgres -c "create user pgx_pw SUPERUSER PASSWORD 'secret'"
-  psql -U postgres -c "create user `whoami`"
-  psql -U postgres -c "create user pgx_replication with replication password 'secret'"
-  # The tricky test user, below, has to actually exist so that it can be used in a test
-  # of aclitem formatting. It turns out aclitems cannot contain non-existing users/roles.
-  psql -U postgres -c "create user \" tricky, ' } \"\" \\ test user \" superuser password 'secret'"
+  createdb -U postgres pgx_test
+  psql -U postgres -f testsetup/postgresql_setup.sql pgx_test
 fi
 
 if [[ "${PGVERSION-}" =~ ^cockroach ]]
 then
-  wget -qO- https://binaries.cockroachdb.com/cockroach-v20.2.5.linux-amd64.tgz | tar xvz
-  sudo mv cockroach-v20.2.5.linux-amd64/cockroach /usr/local/bin/
+  wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
+  sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
   cockroach start-single-node --insecure --background --listen-addr=localhost
   cockroach sql --insecure -e 'create database pgx_test'
 fi

conn.go

@@ -2,13 +2,15 @@ package pgx
 
 import (
 	"context"
+	"crypto/sha256"
+	"database/sql"
+	"encoding/hex"
 	"errors"
 	"fmt"
 	"strconv"
 	"strings"
 	"time"
 
-	"github.com/jackc/pgx/v5/internal/anynil"
 	"github.com/jackc/pgx/v5/internal/sanitize"
 	"github.com/jackc/pgx/v5/internal/stmtcache"
 	"github.com/jackc/pgx/v5/pgconn"
@@ -35,13 +37,18 @@ type ConnConfig struct {
 	// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
 	// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
-	// PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
+	// PGBouncer. In this case it may be preferable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
 	// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
 	DefaultQueryExecMode QueryExecMode
 
 	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
 }
 
+// ParseConfigOptions contains options that control how a config is built, such as GetSSLPassword.
+type ParseConfigOptions struct {
+	pgconn.ParseConfigOptions
+}
+
 // Copy returns a deep copy of the config that is safe to use and modify.
 // The only exception is the tls.Config:
 // according to the tls.Config docs it must not be modified after creation.
@@ -94,11 +101,33 @@ func (ident Identifier) Sanitize() string {
 	return strings.Join(parts, ".")
 }
 
-// ErrNoRows occurs when rows are expected but none are returned.
-var ErrNoRows = errors.New("no rows in result set")
+var (
+	// ErrNoRows occurs when rows are expected but none are returned.
+	ErrNoRows = newProxyErr(sql.ErrNoRows, "no rows in result set")
+	// ErrTooManyRows occurs when more rows than expected are returned.
+	ErrTooManyRows = errors.New("too many rows in result set")
+)
+
+func newProxyErr(background error, msg string) error {
+	return &proxyError{
+		msg:        msg,
+		background: background,
+	}
+}
+
+type proxyError struct {
+	msg        string
+	background error
+}
+
+func (err *proxyError) Error() string { return err.msg }
+
+func (err *proxyError) Unwrap() error { return err.background }
 
-var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
-var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
+var (
+	errDisabledStatementCache   = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
+	errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
+)
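
Because ErrNoRows is now a proxy error that unwraps to database/sql's sentinel, both sentinels match with errors.Is. A minimal sketch, assuming a context ctx, a connected *pgx.Conn conn, and a widgets table (all placeholders):

var name string
err := conn.QueryRow(ctx, "select name from widgets where id = $1", 42).Scan(&name)
// Both checks succeed on a missing row: the first by identity,
// the second through the proxy's Unwrap to sql.ErrNoRows.
if errors.Is(err, pgx.ErrNoRows) { /* no matching row */ }
if errors.Is(err, sql.ErrNoRows) { /* also works for database/sql-style callers */ }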
 
 // Connect establishes a connection with a PostgreSQL server with a connection string. See
 // pgconn.Connect for details.
@@ -110,6 +139,16 @@ func Connect(ctx context.Context, connString string) (*Conn, error) {
 	return connect(ctx, connConfig)
 }
 
+// ConnectWithOptions behaves exactly like Connect with the addition of options. At present, options is only used to
+// provide a GetSSLPassword function.
+func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
+	connConfig, err := ParseConfigWithOptions(connString, options)
+	if err != nil {
+		return nil, err
+	}
+	return connect(ctx, connConfig)
+}
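
A minimal sketch of ConnectWithOptions, assuming a client certificate with a passphrase-protected key; the DSN, file paths, and environment variable are placeholders:

opts := pgx.ParseConfigOptions{ParseConfigOptions: pgconn.ParseConfigOptions{
	GetSSLPassword: func(ctx context.Context) string {
		return os.Getenv("PGSSLKEYPASSWORD") // however the key passphrase is sourced
	},
}}
conn, err := pgx.ConnectWithOptions(ctx,
	"postgres://app@localhost/app?sslmode=verify-full&sslcert=client.crt&sslkey=client.key", opts)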
+
 // ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
 // connConfig must have been created by ParseConfig.
 func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
@@ -120,22 +159,10 @@ func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
 	return connect(ctx, connConfig)
 }
 
-// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
-// does. In addition, it accepts the following options:
-//
-//	default_query_exec_mode
-//		Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
-//		QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
-//
-//	statement_cache_capacity
-//		The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
-//		Default: 512.
-//
-//	description_cache_capacity
-//		The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
-//		Default: 512.
-func ParseConfig(connString string) (*ConnConfig, error) {
-	config, err := pgconn.ParseConfig(connString)
+// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At present, options is
+// only used to provide a GetSSLPassword function.
+func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
+	config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
 	if err != nil {
 		return nil, err
 	}
@@ -145,7 +172,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
 		delete(config.RuntimeParams, "statement_cache_capacity")
 		n, err := strconv.ParseInt(s, 10, 32)
 		if err != nil {
-			return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
+			return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err)
 		}
 		statementCacheCapacity = int(n)
 	}
@@ -155,7 +182,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
 		delete(config.RuntimeParams, "description_cache_capacity")
 		n, err := strconv.ParseInt(s, 10, 32)
 		if err != nil {
-			return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
+			return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err)
 		}
 		descriptionCacheCapacity = int(n)
 	}
@@ -175,7 +202,7 @@ func ParseConfig(connString string) (*ConnConfig, error) {
 		case "simple_protocol":
 			defaultQueryExecMode = QueryExecModeSimpleProtocol
 		default:
-			return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err)
+			return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err)
 		}
 	}
@@ -191,6 +218,24 @@ func ParseConfig(connString string) (*ConnConfig, error) {
 	return connConfig, nil
 }
 
+// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that [pgconn.ParseConfig]
+// does. In addition, it accepts the following options:
+//
+//   - default_query_exec_mode.
+//     Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
+//     QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
+//
+//   - statement_cache_capacity.
+//     The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
+//     Default: 512.
+//
+//   - description_cache_capacity.
+//     The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
+//     Default: 512.
+func ParseConfig(connString string) (*ConnConfig, error) {
+	return ParseConfigWithOptions(connString, ParseConfigOptions{})
+}
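
The three pgx-specific options documented above ride along in the connection string. A sketch with example values (DSN and capacities are placeholders):

config, err := pgx.ParseConfig(
	"postgres://app:secret@localhost:5432/app" +
		"?default_query_exec_mode=cache_describe" +
		"&statement_cache_capacity=256" +
		"&description_cache_capacity=256")
if err != nil {
	return err
}
conn, err := pgx.ConnectConfig(ctx, config)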
 
 // connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
 func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
 	if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
@@ -248,7 +293,7 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
 	return c, nil
 }
 
-// Close closes a connection. It is safe to call Close on a already closed
+// Close closes a connection. It is safe to call Close on an already closed
 // connection.
 func (c *Conn) Close(ctx context.Context) error {
 	if c.IsClosed() {
@@ -259,12 +304,15 @@ func (c *Conn) Close(ctx context.Context) error {
 	return err
 }
 
-// Prepare creates a prepared statement with name and sql. sql can contain placeholders
-// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
+// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters. These
+// placeholders are referenced positionally as $1, $2, etc. name can be used instead of sql with Query, QueryRow, and
+// Exec to execute the statement. It can also be used with Batch.Queue.
 //
-// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
-// name and sql arguments. This allows a code path to Prepare and Query/Exec without
-// concern for if the statement has already been prepared.
+// The underlying PostgreSQL identifier for the prepared statement will be name if name != sql or a digest of sql if
+// name == sql.
+//
+// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same name and sql arguments. This
+// allows a code path to Prepare and Query/Exec without concern for whether the statement has already been prepared.
 func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
 	if c.prepareTracer != nil {
 		ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
@@ -286,22 +334,60 @@ func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
 		}()
 	}
 
-	sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
+	var psName, psKey string
+	if name == sql {
+		digest := sha256.Sum256([]byte(sql))
+		psName = "stmt_" + hex.EncodeToString(digest[0:24])
+		psKey = sql
+	} else {
+		psName = name
+		psKey = name
+	}
+
+	sd, err = c.pgConn.Prepare(ctx, psName, sql, nil)
 	if err != nil {
 		return nil, err
 	}
 
-	if name != "" {
-		c.preparedStatements[name] = sd
+	if psKey != "" {
+		c.preparedStatements[psKey] = sd
 	}
 
 	return sd, nil
 }
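
A sketch of the name == sql case described above (query is an example, ctx and conn assumed): the statement is prepared server-side under a digest-derived identifier but stays addressable by its sql text client-side.

// Prepared on the server as "stmt_" + hex(SHA-256(sql)[:24]); stored
// under the sql text itself in c.preparedStatements.
_, err := conn.Prepare(ctx, "select $1::int", "select $1::int")
if err != nil {
	return err
}
var n int
// Query finds the sql text in preparedStatements and reuses the statement.
err = conn.QueryRow(ctx, "select $1::int", 7).Scan(&n)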
-// Deallocate released a prepared statement
+// Deallocate releases a prepared statement. Calling Deallocate on a non-existent prepared statement will succeed.
 func (c *Conn) Deallocate(ctx context.Context, name string) error {
-	delete(c.preparedStatements, name)
-	_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
-	return err
+	var psName string
+	sd := c.preparedStatements[name]
+	if sd != nil {
+		psName = sd.Name
+	} else {
+		psName = name
+	}
+
+	err := c.pgConn.Deallocate(ctx, psName)
+	if err != nil {
+		return err
+	}
+
+	if sd != nil {
+		delete(c.preparedStatements, name)
+	}
+
+	return nil
+}
+
+// DeallocateAll releases all previously prepared statements from the server and client, and resets the statement and
+// description caches.
+func (c *Conn) DeallocateAll(ctx context.Context) error {
+	c.preparedStatements = map[string]*pgconn.StatementDescription{}
+	if c.config.StatementCacheCapacity > 0 {
+		c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
+	}
+	if c.config.DescriptionCacheCapacity > 0 {
+		c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
+	}
+	_, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll()
+	return err
 }
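
A sketch of the pairing (statement name and query are examples): Deallocate resolves the client-side name to the server-side statement name, while DeallocateAll wipes everything including the caches.

_, err := conn.Prepare(ctx, "get_widget", "select name from widgets where id = $1")
if err != nil {
	return err
}
err = conn.Deallocate(ctx, "get_widget") // succeeds even for an unknown name
err = conn.DeallocateAll(ctx)            // server-side "deallocate all" plus cache reset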
@@ -334,7 +420,7 @@ func (c *Conn) IsClosed() bool {
 	return c.pgConn.IsClosed()
 }
 
-func (c *Conn) die(err error) {
+func (c *Conn) die() {
 	if c.IsClosed() {
 		return
 	}
@@ -348,11 +434,9 @@ func quoteIdentifier(s string) string {
 	return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
 }
-// Ping executes an empty sql statement against the *Conn
-// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
+// Ping delegates to the underlying *pgconn.PgConn.Ping.
 func (c *Conn) Ping(ctx context.Context) error {
-	_, err := c.Exec(ctx, ";")
-	return err
+	return c.pgConn.Ping(ctx)
 }
 
 // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
@@ -407,7 +491,10 @@ optionLoop:
 	}
 
 	if queryRewriter != nil {
-		sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
+		sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
+		if err != nil {
+			return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %w", err)
+		}
 	}
 
 	// Always use simple protocol when there are no arguments.
@@ -426,7 +513,7 @@ optionLoop:
 		}
 		sd := c.statementCache.Get(sql)
 		if sd == nil {
-			sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
+			sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
 			if err != nil {
 				return pgconn.CommandTag{}, err
 			}
@@ -444,6 +531,7 @@ optionLoop:
 		if err != nil {
 			return pgconn.CommandTag{}, err
 		}
+		c.descriptionCache.Put(sd)
 	}
 
 	return c.execParams(ctx, sd, arguments)
@@ -472,7 +560,7 @@ func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) {
 	mrr := c.pgConn.Exec(ctx, sql)
 	for mrr.NextResult() {
-		commandTag, err = mrr.ResultReader().Close()
+		commandTag, _ = mrr.ResultReader().Close()
 	}
 	err = mrr.Close()
 	return commandTag, err
@@ -500,14 +588,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
 	return result.CommandTag, result.Err
 }
 
-type unknownArgumentTypeQueryExecModeExecError struct {
-	arg any
-}
-
-func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
-	return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
-}
-
 func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
 	err := c.eqb.Build(c.typeMap, nil, args)
 	if err != nil {
@@ -538,40 +618,57 @@ type QueryExecMode int32
 
 const (
 	_ QueryExecMode = iota
 
-	// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single
-	// round trip after the statement is cached. This is the default.
+	// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single round
+	// trip after the statement is cached. This is the default. If the database schema is modified or the search_path is
+	// changed after a statement is cached then the first execution of a previously cached query may fail. e.g. If the
+	// number of columns returned by a "SELECT *" changes or the type of a column is changed.
 	QueryExecModeCacheStatement
 
-	// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the
-	// extended protocol. Queries are executed in a single round trip after the description is cached. If the database
-	// schema is modified or the search_path is changed this may result in undetected result decoding errors.
+	// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the extended
+	// protocol. Queries are executed in a single round trip after the description is cached. If the database schema is
+	// modified or the search_path is changed after a statement is cached then the first execution of a previously cached
+	// query may fail. e.g. If the number of columns returned by a "SELECT *" changes or the type of a column is changed.
 	QueryExecModeCacheDescribe
 
 	// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
-	// to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even
-	// when the the database schema is modified concurrently.
+	// to execute. It does not use named prepared statements. But it does use the unnamed prepared statement to get the
+	// statement description on the first round trip and then uses it to execute the query on the second round trip. This
+	// may cause problems with connection poolers that switch the underlying connection between round trips. It is safe
+	// even when the database schema is modified concurrently.
 	QueryExecModeDescribeExec
 
 	// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
 	// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
 	// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
-	// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
+	// unregistered or ambiguous. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
 	// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
+	//
+	// On rare occasions user defined types may behave differently when encoded in the text format instead of the binary
+	// format. For example, this could happen if a "type RomanNumeral int32" implements fmt.Stringer to format integers as
+	// Roman numerals (e.g. 7 is VII). The binary format would properly encode the integer 7 as the binary value for 7.
+	// But the text format would encode the integer 7 as the string "VII". As QueryExecModeExec uses the text format, it
+	// is possible that changing query mode from another mode to QueryExecModeExec could change the behavior of the query.
+	// This should not occur with types pgx supports directly and can be avoided by registering the types with
+	// pgtype.Map.RegisterDefaultPgType and implementing the appropriate type interfaces. In the case of RomanNumeral, it
+	// should implement pgtype.Int64Valuer.
 	QueryExecModeExec
 
-	// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
-	// Queries are executed in a single round trip. Type mappings can be registered with
-	// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
-	// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
-	// a map[string]string directly as an argument. This mode cannot.
+	// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is
+	// especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used
+	// instead for text type values including json and jsonb. Type mappings can be registered with
+	// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
+	// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a
+	// map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip.
 	//
-	// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor
-	// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
+	// QueryExecModeSimpleProtocol should have the same user application visible behavior as QueryExecModeExec. This
+	// includes the warning regarding differences in text format and binary format encoding with user defined types. There
+	// may be other minor exceptions such as behavior when multiple result returning queries are erroneously sent in a
+	// single string.
 	//
 	// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
-	// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
-	// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
-	// not support the extended protocol.
+	// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol should
+	// only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does not
+	// support the extended protocol.
 	QueryExecModeSimpleProtocol
 )
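
The mode can be fixed connection-wide or overridden per query by passing it as the first query argument; a sketch, with the DSN and query as placeholders:

config, err := pgx.ParseConfig("postgres://app@localhost/app")
if err != nil {
	return err
}
config.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
conn, err := pgx.ConnectConfig(ctx, config)
if err != nil {
	return err
}
// One-off override: no prepared statement, single round trip.
rows, err := conn.Query(ctx, "select id from widgets where weight > $1",
	pgx.QueryExecModeExec, 100)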
@@ -600,7 +697,7 @@ type QueryResultFormatsByOID map[uint32]int16
 
 // QueryRewriter rewrites a query when used as the first arguments to a query method.
 type QueryRewriter interface {
-	RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any)
+	RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
}
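
pgx's NamedArgs is the built-in implementation of this interface. A hypothetical rewriter sketched against the new three-value signature (the comment-tagging scheme is invented for illustration):

type commentTagger struct{ tag string }

func (t commentTagger) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (string, []any, error) {
	if t.tag == "" {
		return "", nil, errors.New("commentTagger: empty tag")
	}
	// Prefix a comment so the query is identifiable in pg_stat_activity.
	return "/* " + t.tag + " */ " + sql, args, nil
}

// Used as the first query argument: conn.Query(ctx, "select 1", commentTagger{tag: "report"})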
 // Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
@@ -611,6 +708,9 @@ type QueryRewriter interface {
 // returned Rows even if an error is returned. The error will be available in rows.Err() after rows are closed. It is
 // allowed to ignore the error returned from Query and handle it in Rows.
 //
+// It is possible for a call of FieldDescriptions on the returned Rows to return nil even if the Query call did not
+// return an error.
+//
 // It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be
 // collected before processing rather than processed while receiving each row. This avoids the possibility of the
 // application processing rows from a query that the server rejected. The CollectRows function is useful here.
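
A sketch of the collect-before-processing pattern the comment recommends (the Widget type and query are examples):

type Widget struct {
	ID   int32
	Name string
}

rows, _ := conn.Query(ctx, "select id, name from widgets")
widgets, err := pgx.CollectRows(rows, pgx.RowToStructByName[Widget])
if err != nil {
	// The server may have rejected the query after sending some rows;
	// nothing has been processed yet at this point.
	return err
}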
@@ -659,7 +759,16 @@ optionLoop:
 	}
 
 	if queryRewriter != nil {
-		sql, args = queryRewriter.RewriteQuery(ctx, c, sql, args)
+		var err error
+		originalSQL := sql
+		originalArgs := args
+		sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
+		if err != nil {
+			rows := c.getRows(ctx, originalSQL, originalArgs)
+			err = fmt.Errorf("rewrite query failed: %w", err)
+			rows.fatal(err)
+			return rows, err
+		}
 	}
 
 	// Bypass any statement caching.
@@ -668,51 +777,17 @@ optionLoop:
 	}
 
 	c.eqb.reset()
-	anynil.NormalizeSlice(args)
 	rows := c.getRows(ctx, sql, args)
 
 	var err error
 	sd, explicitPreparedStatement := c.preparedStatements[sql]
 	if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
 		if sd == nil {
-			switch mode {
-			case QueryExecModeCacheStatement:
-				if c.statementCache == nil {
-					err = errDisabledStatementCache
-					rows.fatal(err)
-					return rows, err
-				}
-				sd = c.statementCache.Get(sql)
-				if sd == nil {
-					sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
-					if err != nil {
-						rows.fatal(err)
-						return rows, err
-					}
-					c.statementCache.Put(sd)
-				}
-			case QueryExecModeCacheDescribe:
-				if c.descriptionCache == nil {
-					err = errDisabledDescriptionCache
-					rows.fatal(err)
-					return rows, err
-				}
-				sd = c.descriptionCache.Get(sql)
-				if sd == nil {
-					sd, err = c.Prepare(ctx, "", sql)
-					if err != nil {
-						rows.fatal(err)
-						return rows, err
-					}
-					c.descriptionCache.Put(sd)
-				}
-			case QueryExecModeDescribeExec:
-				sd, err = c.Prepare(ctx, "", sql)
-				if err != nil {
-					rows.fatal(err)
-					return rows, err
-				}
-			}
+			sd, err = c.getStatementDescription(ctx, mode, sql)
+			if err != nil {
+				rows.fatal(err)
+				return rows, err
+			}
 		}
 
 		if len(sd.ParamOIDs) != len(args) {
@@ -781,6 +856,47 @@ optionLoop:
 	return rows, rows.err
 }
 
+// getStatementDescription returns the statement description of the sql query
+// according to the given mode.
+//
+// If the mode is one that doesn't require knowing the param and result OIDs
+// then nil is returned without error.
+func (c *Conn) getStatementDescription(
+	ctx context.Context,
+	mode QueryExecMode,
+	sql string,
+) (sd *pgconn.StatementDescription, err error) {
+	switch mode {
+	case QueryExecModeCacheStatement:
+		if c.statementCache == nil {
+			return nil, errDisabledStatementCache
+		}
+		sd = c.statementCache.Get(sql)
+		if sd == nil {
+			sd, err = c.Prepare(ctx, stmtcache.StatementName(sql), sql)
+			if err != nil {
+				return nil, err
+			}
+			c.statementCache.Put(sd)
+		}
+	case QueryExecModeCacheDescribe:
+		if c.descriptionCache == nil {
+			return nil, errDisabledDescriptionCache
+		}
+		sd = c.descriptionCache.Get(sql)
+		if sd == nil {
+			sd, err = c.Prepare(ctx, "", sql)
+			if err != nil {
+				return nil, err
+			}
+			c.descriptionCache.Put(sd)
+		}
+	case QueryExecModeDescribeExec:
+		return c.Prepare(ctx, "", sql)
+	}
+
+	return sd, err
+}
 // QueryRow is a convenience wrapper over Query. Any error that occurs while
 // querying is deferred until calling Scan on the returned Row. That Row will
 // error with ErrNoRows if no rows are returned.
@@ -792,6 +908,9 @@ func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
 // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
 // explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
 // is used again.
+//
+// Depending on the QueryExecMode, all queries may be prepared before any are executed. This means that creating a table
+// and using it in a subsequent query in the same batch can fail.
 func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
 	if c.batchTracer != nil {
 		ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
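
A sketch of batch usage under the caveat above: DDL and queries that depend on it belong in separate batches (table and queries are examples, ctx and conn assumed):

b := &pgx.Batch{}
b.Queue("insert into widgets(name) values($1)", "anvil")
b.Queue("select count(*) from widgets")

br := conn.SendBatch(ctx, b)
if _, err := br.Exec(); err != nil { // result of the insert
	return err
}
var n int64
if err := br.QueryRow().Scan(&n); err != nil { // result of the select
	return err
}
return br.Close()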
@@ -807,15 +926,14 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
 		return &batchResults{ctx: ctx, conn: c, err: err}
 	}
 
-	mode := c.config.DefaultQueryExecMode
-
-	for _, bi := range b.queuedQueries {
+	for _, bi := range b.QueuedQueries {
 		var queryRewriter QueryRewriter
-		sql := bi.query
-		arguments := bi.arguments
+		sql := bi.SQL
+		arguments := bi.Arguments
 
 	optionLoop:
 		for len(arguments) > 0 {
+			// Update Batch.Queue function comment when additional options are implemented
 			switch arg := arguments[0].(type) {
 			case QueryRewriter:
 				queryRewriter = arg
@@ -826,20 +944,26 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
 		}
 
 		if queryRewriter != nil {
-			sql, arguments = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
+			var err error
+			sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
+			if err != nil {
+				return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %w", err)}
+			}
 		}
 
-		bi.query = sql
-		bi.arguments = arguments
+		bi.SQL = sql
+		bi.Arguments = arguments
 	}
 
+	// TODO: changing mode per batch? Update Batch.Queue function comment when implemented
+	mode := c.config.DefaultQueryExecMode
 	if mode == QueryExecModeSimpleProtocol {
 		return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
 	}
 
 	// All other modes use extended protocol and thus can use prepared statements.
-	for _, bi := range b.queuedQueries {
-		if sd, ok := c.preparedStatements[bi.query]; ok {
+	for _, bi := range b.QueuedQueries {
+		if sd, ok := c.preparedStatements[bi.SQL]; ok {
 			bi.sd = sd
 		}
 	}
@@ -860,11 +984,11 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
 func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
 	var sb strings.Builder
-	for i, bi := range b.queuedQueries {
+	for i, bi := range b.QueuedQueries {
 		if i > 0 {
 			sb.WriteByte(';')
 		}
-		sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
+		sql, err := c.sanitizeForSimpleQuery(bi.SQL, bi.Arguments...)
 		if err != nil {
 			return &batchResults{ctx: ctx, conn: c, err: err}
 		}
@@ -883,21 +1007,21 @@ func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
 func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
 	batch := &pgconn.Batch{}
 
-	for _, bi := range b.queuedQueries {
+	for _, bi := range b.QueuedQueries {
 		sd := bi.sd
 		if sd != nil {
-			err := c.eqb.Build(c.typeMap, sd, bi.arguments)
+			err := c.eqb.Build(c.typeMap, sd, bi.Arguments)
 			if err != nil {
 				return &batchResults{ctx: ctx, conn: c, err: err}
 			}
 
 			batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
 		} else {
-			err := c.eqb.Build(c.typeMap, nil, bi.arguments)
+			err := c.eqb.Build(c.typeMap, nil, bi.Arguments)
 			if err != nil {
 				return &batchResults{ctx: ctx, conn: c, err: err}
 			}
-			batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
+			batch.ExecParams(bi.SQL, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
 		}
 	}
@@ -916,24 +1040,24 @@ func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
 func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
 	if c.statementCache == nil {
-		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
+		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache, closed: true}
 	}
 
 	distinctNewQueries := []*pgconn.StatementDescription{}
 	distinctNewQueriesIdxMap := make(map[string]int)
 
-	for _, bi := range b.queuedQueries {
+	for _, bi := range b.QueuedQueries {
 		if bi.sd == nil {
-			sd := c.statementCache.Get(bi.query)
+			sd := c.statementCache.Get(bi.SQL)
 			if sd != nil {
 				bi.sd = sd
 			} else {
-				if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
+				if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
 					bi.sd = distinctNewQueries[idx]
 				} else {
 					sd = &pgconn.StatementDescription{
-						Name: stmtcache.NextStatementName(),
-						SQL:  bi.query,
+						Name: stmtcache.StatementName(bi.SQL),
+						SQL:  bi.SQL,
 					}
 					distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
 					distinctNewQueries = append(distinctNewQueries, sd)
@@ -948,23 +1072,23 @@ func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
 func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
 	if c.descriptionCache == nil {
-		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
+		return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache, closed: true}
 	}
 
 	distinctNewQueries := []*pgconn.StatementDescription{}
 	distinctNewQueriesIdxMap := make(map[string]int)
 
-	for _, bi := range b.queuedQueries {
+	for _, bi := range b.QueuedQueries {
 		if bi.sd == nil {
-			sd := c.descriptionCache.Get(bi.query)
+			sd := c.descriptionCache.Get(bi.SQL)
 			if sd != nil {
 				bi.sd = sd
 			} else {
-				if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
+				if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
 					bi.sd = distinctNewQueries[idx]
 				} else {
 					sd = &pgconn.StatementDescription{
-						SQL: bi.query,
+						SQL: bi.SQL,
 					}
 					distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
 					distinctNewQueries = append(distinctNewQueries, sd)
@@ -981,13 +1105,13 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
 	distinctNewQueries := []*pgconn.StatementDescription{}
 	distinctNewQueriesIdxMap := make(map[string]int)
 
-	for _, bi := range b.queuedQueries {
+	for _, bi := range b.QueuedQueries {
 		if bi.sd == nil {
-			if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
+			if idx, present := distinctNewQueriesIdxMap[bi.SQL]; present {
 				bi.sd = distinctNewQueries[idx]
 			} else {
 				sd := &pgconn.StatementDescription{
-					SQL: bi.query,
+					SQL: bi.SQL,
 				}
 				distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
 				distinctNewQueries = append(distinctNewQueries, sd)
@@ -1000,33 +1124,51 @@ func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
 }
 
 func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
-	pipeline := c.pgConn.StartPipeline(context.Background())
+	pipeline := c.pgConn.StartPipeline(ctx)
 	defer func() {
-		if pbr.err != nil {
+		if pbr != nil && pbr.err != nil {
 			pipeline.Close()
 		}
 	}()
 
 	// Prepare any needed queries
 	if len(distinctNewQueries) > 0 {
+		err := func() (err error) {
 			for _, sd := range distinctNewQueries {
 				pipeline.SendPrepare(sd.Name, sd.SQL, nil)
 			}
 
-		err := pipeline.Sync()
-		if err != nil {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
-		}
+			// Store all statements we are preparing into the cache. It's fine if it overflows because
+			// HandleInvalidated will clean them up later.
+			if sdCache != nil {
+				for _, sd := range distinctNewQueries {
+					sdCache.Put(sd)
+				}
+			}
+
+			// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
+			defer func() {
+				if err != nil && sdCache != nil {
+					for _, sd := range distinctNewQueries {
+						sdCache.Invalidate(sd.SQL)
+					}
+				}
+			}()
+
+			err = pipeline.Sync()
+			if err != nil {
+				return err
+			}
 
 			for _, sd := range distinctNewQueries {
 				results, err := pipeline.GetResults()
 				if err != nil {
-				return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
+					return err
 				}
 
 				resultSD, ok := results.(*pgconn.StatementDescription)
 				if !ok {
-				return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
+					return fmt.Errorf("expected statement description, got %T", results)
 				}
 
 				// Fill in the previously empty / pending statement descriptions.
@@ -1036,27 +1178,28 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
 
 			results, err := pipeline.GetResults()
 			if err != nil {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
+				return err
 			}
 
 			_, ok := results.(*pgconn.PipelineSync)
 			if !ok {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
-		}
-	}
-
-	// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
-	if sdCache != nil {
-		for _, sd := range distinctNewQueries {
-			sdCache.Put(sd)
+				return fmt.Errorf("expected sync, got %T", results)
+			}
+
+			return nil
+		}()
+		if err != nil {
+			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
 		}
 	}
 
 	// Queue the queries.
-	for _, bi := range b.queuedQueries {
-		err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
+	for _, bi := range b.QueuedQueries {
+		err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
 		if err != nil {
-			return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
+			// Wrap the error so the user can understand which query failed inside the batch.
+			err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
+			return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
 		}
 
 		if bi.sd.Name == "" {
@@ -1068,7 +1211,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
 
 	err := pipeline.Sync()
 	if err != nil {
-		return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
+		return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
 	}
 	return &pipelineBatchResults{
@@ -1100,7 +1243,15 @@ func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
 	return sanitize.SanitizeSQL(sql, valueArgs...)
 }
 
-// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
+// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration. typeName must be
+// the name of a type where the underlying type(s) is already understood by pgx. It is for derived types. In particular,
+// typeName must be one of the following:
+//   - An array type name of a type that is already registered. e.g. "_foo" when "foo" is registered.
+//   - A composite type name where all field types are already registered.
+//   - A domain type name where the base type is already registered.
+//   - An enum type name.
+//   - A range type name where the element type is already registered.
+//   - A multirange type name where the element type is already registered.
 func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
 	var oid uint32
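
A sketch of loading a derived type and its array form; registration order matters because "_mood" depends on "mood" (type names are examples):

dt, err := conn.LoadType(ctx, "mood") // an enum created in this database
if err != nil {
	return err
}
conn.TypeMap().RegisterType(dt)

arr, err := conn.LoadType(ctx, "_mood") // array of the enum; element is now registered
if err != nil {
	return err
}
conn.TypeMap().RegisterType(arr)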
@ -1110,8 +1261,9 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
} }
var typtype string var typtype string
var typbasetype uint32
err = c.QueryRow(ctx, "select typtype::text from pg_type where oid=$1", oid).Scan(&typtype) err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1136,8 +1288,39 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err
 		}

 		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
+	case "d": // domain
+		dt, ok := c.TypeMap().TypeForOID(typbasetype)
+		if !ok {
+			return nil, errors.New("domain base type OID not registered")
+		}
+		return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
 	case "e": // enum
 		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
+	case "r": // range
+		elementOID, err := c.getRangeElementOID(ctx, oid)
+		if err != nil {
+			return nil, err
+		}
+		dt, ok := c.TypeMap().TypeForOID(elementOID)
+		if !ok {
+			return nil, errors.New("range element OID not registered")
+		}
+		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
+	case "m": // multirange
+		elementOID, err := c.getMultiRangeElementOID(ctx, oid)
+		if err != nil {
+			return nil, err
+		}
+		dt, ok := c.TypeMap().TypeForOID(elementOID)
+		if !ok {
+			return nil, errors.New("multirange element OID not registered")
+		}
+		return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
 	default:
 		return &pgtype.Type{}, errors.New("unknown typtype")
 	}
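With the domain, range, and multirange branches above, LoadType can resolve most derived types in one call, provided the element or base type is already registered. A minimal usage sketch (the `mydomain` type name and the connection string are hypothetical, not part of this changeset):

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // hypothetical connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	// Load a hypothetical domain type; its base type must already be
	// registered, which is always true for built-in base types.
	dt, err := conn.LoadType(ctx, "mydomain")
	if err != nil {
		log.Fatal(err)
	}
	conn.TypeMap().RegisterType(dt)

	// Loading "_mydomain" as well lets pgx handle mydomain[] columns.
	arr, err := conn.LoadType(ctx, "_mydomain")
	if err != nil {
		log.Fatal(err)
	}
	conn.TypeMap().RegisterType(arr)
}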
@@ -1154,6 +1337,28 @@ func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, erro
 	return typelem, nil
 }

+func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
+	var typelem uint32
+	err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
+	if err != nil {
+		return 0, err
+	}
+	return typelem, nil
+}
+
+func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
+	var typelem uint32
+	err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
+	if err != nil {
+		return 0, err
+	}
+	return typelem, nil
+}
+
 func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
 	var typrelid uint32

@@ -1168,6 +1373,8 @@ func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.Com
 	rows, _ := c.Query(ctx, `select attname, atttypid
 from pg_attribute
 where attrelid=$1
+and not attisdropped
+and attnum > 0
 order by attnum`,
 		typrelid,
 	)
@@ -1187,17 +1394,17 @@ order by attnum`,
 }

 func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
-	if c.pgConn.TxStatus() != 'I' {
+	if txStatus := c.pgConn.TxStatus(); txStatus != 'I' && txStatus != 'T' {
 		return nil
 	}

 	if c.descriptionCache != nil {
-		c.descriptionCache.HandleInvalidated()
+		c.descriptionCache.RemoveInvalidated()
 	}

 	var invalidatedStatements []*pgconn.StatementDescription
 	if c.statementCache != nil {
-		invalidatedStatements = c.statementCache.HandleInvalidated()
+		invalidatedStatements = c.statementCache.GetInvalidated()
 	}

 	if len(invalidatedStatements) == 0 {

@@ -1221,5 +1428,10 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
 		return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
 	}

+	c.statementCache.RemoveInvalidated()
+	for _, sd := range invalidatedStatements {
+		delete(c.preparedStatements, sd.Name)
+	}
+
 	return nil
 }

conn_internal_test.go (new file, +55 lines)

@@ -0,0 +1,55 @@
package pgx
import (
"context"
"fmt"
"os"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func mustParseConfig(t testing.TB, connString string) *ConnConfig {
config, err := ParseConfig(connString)
require.Nil(t, err)
return config
}
func mustConnect(t testing.TB, config *ConnConfig) *Conn {
conn, err := ConnectConfig(context.Background(), config)
if err != nil {
t.Fatalf("Unable to establish connection: %v", err)
}
return conn
}
// Ensures the connection limits the size of its cached objects.
// This test examines the internals of *Conn so must be in the same package.
func TestStmtCacheSizeLimit(t *testing.T) {
const cacheLimit = 16
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
connConfig.StatementCacheCapacity = cacheLimit
conn := mustConnect(t, connConfig)
defer func() {
err := conn.Close(context.Background())
if err != nil {
t.Fatal(err)
}
}()
// run a set of unique queries that should overflow the cache
ctx := context.Background()
for i := 0; i < cacheLimit*2; i++ {
uniqueString := fmt.Sprintf("unique %d", i)
uniqueSQL := fmt.Sprintf("select '%s'", uniqueString)
var output string
err := conn.QueryRow(ctx, uniqueSQL).Scan(&output)
require.NoError(t, err)
require.Equal(t, uniqueString, output)
}
// preparedStatements contains cacheLimit+1 because deallocation happens before the query
assert.Len(t, conn.preparedStatements, cacheLimit+1)
assert.Equal(t, cacheLimit, conn.statementCache.Len())
}


@@ -3,6 +3,7 @@ package pgx_test

 import (
 	"bytes"
 	"context"
+	"database/sql"
 	"os"
 	"strings"
 	"sync"

@@ -197,10 +198,28 @@ func TestParseConfigExtractsDefaultQueryExecMode(t *testing.T) {
 	}
 }
func TestParseConfigErrors(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
connString string
expectedErrSubstring string
}{
{"default_query_exec_mode=does_not_exist", "does_not_exist"},
} {
config, err := pgx.ParseConfig(tt.connString)
require.Nil(t, config)
require.ErrorContains(t, err, tt.expectedErrSubstring)
}
}
 func TestExec(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); results.String() != "CREATE TABLE" {
 			t.Error("Unexpected results from Exec")
 		}
@@ -236,14 +255,17 @@ type testQueryRewriter struct {
 	args []any
 }

-func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any) {
-	return qr.sql, qr.args
+func (qr *testQueryRewriter) RewriteQuery(ctx context.Context, conn *pgx.Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
+	return qr.sql, qr.args, nil
 }
 func TestExecWithQueryRewriter(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		qr := testQueryRewriter{sql: "select $1::int", args: []any{42}}
 		_, err := conn.Exec(ctx, "should be replaced", &qr)
 		require.NoError(t, err)
@@ -253,7 +275,10 @@ func TestExecWithQueryRewriter(t *testing.T) {

 func TestExecFailure(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		if _, err := conn.Exec(context.Background(), "selct;"); err == nil {
 			t.Fatal("Expected SQL syntax error")
 		}
@@ -269,7 +294,10 @@ func TestExecFailure(t *testing.T) {

 func TestExecFailureWithArguments(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		_, err := conn.Exec(context.Background(), "selct $1;", 1)
 		if err == nil {
 			t.Fatal("Expected SQL syntax error")
@@ -284,8 +312,11 @@ func TestExecFailureWithArguments(t *testing.T) {

 func TestExecContextWithoutCancelation(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
-		ctx, cancelFunc := context.WithCancel(context.Background())
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+		ctx, cancelFunc := context.WithCancel(ctx)
 		defer cancelFunc()

 		commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);")
@@ -302,8 +333,11 @@ func TestExecContextWithoutCancelation(t *testing.T) {

 func TestExecContextFailureWithoutCancelation(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
-		ctx, cancelFunc := context.WithCancel(context.Background())
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+		ctx, cancelFunc := context.WithCancel(ctx)
 		defer cancelFunc()

 		_, err := conn.Exec(ctx, "selct;")
@@ -324,8 +358,11 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {

 func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
-		ctx, cancelFunc := context.WithCancel(context.Background())
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+		ctx, cancelFunc := context.WithCancel(ctx)
 		defer cancelFunc()

 		_, err := conn.Exec(ctx, "selct $1;", 1)
@@ -424,7 +461,7 @@ func TestPrepare(t *testing.T) {
 		t.Errorf("Prepared statement did not return expected value: %v", s)
 	}

-	err = conn.Deallocate(context.Background(), "test")
+	err = conn.DeallocateAll(context.Background())
 	if err != nil {
 		t.Errorf("conn.Deallocate failed: %v", err)
 	}
@@ -446,9 +483,10 @@ func TestPrepareBadSQLFailure(t *testing.T) {

 func TestPrepareIdempotency(t *testing.T) {
 	t.Parallel()

-	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
-	defer closeConn(t, conn)
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		for i := 0; i < 2; i++ {
 			_, err := conn.Prepare(context.Background(), "test", "select 42::integer")
 			if err != nil {

@@ -471,12 +509,16 @@ func TestPrepareIdempotency(t *testing.T) {
 			t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't")
 			return
 		}
+	})
 }
 func TestPrepareStatementCacheModes(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		_, err := conn.Prepare(context.Background(), "test", "select $1::text")
 		require.NoError(t, err)

@@ -487,6 +529,91 @@ func TestPrepareStatementCacheModes(t *testing.T) {
 	})
 }
func TestPrepareWithDigestedName(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := "select $1::text"
sd, err := conn.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
var s string
err = conn.QueryRow(ctx, sql, "hello").Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)
err = conn.Deallocate(ctx, sql)
require.NoError(t, err)
})
}
// https://github.com/jackc/pgx/pull/1795
func TestDeallocateInAbortedTransaction(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
tx, err := conn.Begin(ctx)
require.NoError(t, err)
sql := "select $1::text"
sd, err := tx.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
var s string
err = tx.QueryRow(ctx, sql, "hello").Scan(&s)
require.NoError(t, err)
require.Equal(t, "hello", s)
_, err = tx.Exec(ctx, "select 1/0") // abort transaction with divide by zero error
require.Error(t, err)
err = conn.Deallocate(ctx, sql)
require.NoError(t, err)
err = tx.Rollback(ctx)
require.NoError(t, err)
sd, err = conn.Prepare(ctx, sql, sql)
require.NoError(t, err)
require.Equal(t, "stmt_2510cc7db17de3f42758a2a29c8b9ef8305d007b997ebdd6", sd.Name)
})
}
func TestDeallocateMissingPreparedStatementStillClearsFromPreparedStatementMap(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Prepare(ctx, "ps", "select $1::text")
require.NoError(t, err)
_, err = conn.Exec(ctx, "deallocate ps")
require.NoError(t, err)
err = conn.Deallocate(ctx, "ps")
require.NoError(t, err)
_, err = conn.Prepare(ctx, "ps", "select $1::text, $2::text")
require.NoError(t, err)
var s1, s2 string
err = conn.QueryRow(ctx, "ps", "hello", "world").Scan(&s1, &s2)
require.NoError(t, err)
require.Equal(t, "hello", s1)
require.Equal(t, "world", s2)
})
}
 func TestListenNotify(t *testing.T) {
 	t.Parallel()

@@ -526,6 +653,7 @@ func TestListenNotify(t *testing.T) {
 	defer cancel()

 	notification, err = listener.WaitForNotification(ctx)
 	assert.True(t, pgconn.Timeout(err))
+	assert.Nil(t, notification)

 	// listener can listen again after a timeout
 	mustExec(t, notifier, "notify chat")
@@ -545,6 +673,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
 	listenerDone := make(chan bool)
 	notifierDone := make(chan bool)
+	listening := make(chan bool)

 	go func() {
 		conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 		defer closeConn(t, conn)

@@ -553,6 +682,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
 		}()

 		mustExec(t, conn, "listen busysafe")
+		listening <- true

 		for i := 0; i < 5000; i++ {
 			var sum int32

@@ -575,7 +705,7 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
 			}

 			if rows.Err() != nil {
-				t.Errorf("conn.Query failed: %v", err)
+				t.Errorf("conn.Query failed: %v", rows.Err())
 				return
 			}

@@ -588,8 +718,6 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
 				t.Errorf("Wrong number of rows: %v", rowCount)
 				return
 			}
-
-			time.Sleep(1 * time.Microsecond)
 		}
 	}()

@@ -600,9 +728,10 @@ func TestListenNotifyWhileBusyIsSafe(t *testing.T) {
 		notifierDone <- true
 	}()

+	<-listening
+
 	for i := 0; i < 100000; i++ {
 		mustExec(t, conn, "notify busysafe, 'hello'")
+		time.Sleep(1 * time.Microsecond)
 	}
 }()
@@ -715,7 +844,10 @@ func TestFatalTxError(t *testing.T) {

 func TestInsertBoolArray(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); results.String() != "CREATE TABLE" {
 			t.Error("Unexpected results from Exec")
 		}
@@ -730,7 +862,10 @@ func TestInsertBoolArray(t *testing.T) {

 func TestInsertTimestampArray(t *testing.T) {
 	t.Parallel()

-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); results.String() != "CREATE TABLE" {
 			t.Error("Unexpected results from Exec")
 		}
@@ -812,7 +947,10 @@ func TestConnInitTypeMap(t *testing.T) {
 }

 func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		pgxtest.SkipCockroachDB(t, conn, "Server does not support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")

 		var n uint64
@@ -828,7 +966,10 @@ func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) {
 }

 func TestDomainType(t *testing.T) {
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		pgxtest.SkipCockroachDB(t, conn, "Server does not support domain types (https://github.com/cockroachdb/cockroach/issues/27796)")

 		// Domain type uint64 is a PostgreSQL domain of underlying type numeric.

@@ -837,24 +978,21 @@ func TestDomainType(t *testing.T) {
 		// uint64 but a result OID of the underlying numeric.

 		var s string
-		err := conn.QueryRow(context.Background(), "select $1::uint64", "24").Scan(&s)
+		err := conn.QueryRow(ctx, "select $1::uint64", "24").Scan(&s)
 		require.NoError(t, err)
 		require.Equal(t, "24", s)

 		// Register type
-		var uint64OID uint32
-		err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID)
-		if err != nil {
-			t.Fatalf("did not find uint64 OID, %v", err)
-		}
-		conn.TypeMap().RegisterType(&pgtype.Type{Name: "uint64", OID: uint64OID, Codec: pgtype.NumericCodec{}})
+		uint64Type, err := conn.LoadType(ctx, "uint64")
+		require.NoError(t, err)
+		conn.TypeMap().RegisterType(uint64Type)

 		var n uint64
-		err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n)
+		err = conn.QueryRow(ctx, "select $1::uint64", uint64(24)).Scan(&n)
 		require.NoError(t, err)

 		// String is still an acceptable argument after registration
-		err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n)
+		err = conn.QueryRow(ctx, "select $1::uint64", "7").Scan(&n)
 		if err != nil {
 			t.Fatal(err)
 		}
@@ -865,7 +1003,10 @@ func TestDomainType(t *testing.T) {
 }

 func TestLoadTypeSameNameInDifferentSchemas(t *testing.T) {
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		pgxtest.SkipCockroachDB(t, conn, "Server does not support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")

 		tx, err := conn.Begin(ctx)

@@ -906,6 +1047,111 @@ create type pgx_b.point as (c text);
 	})
 }
func TestLoadCompositeType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does support composite types (https://github.com/cockroachdb/cockroach/issues/27792)")
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "create type compositetype as (attr1 int, attr2 int)")
require.NoError(t, err)
_, err = tx.Exec(ctx, "alter type compositetype drop attribute attr1")
require.NoError(t, err)
_, err = conn.LoadType(ctx, "compositetype")
require.NoError(t, err)
})
}
func TestLoadRangeType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does support range types")
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi)")
require.NoError(t, err)
// Register types
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
var inputRangeType = pgtype.Range[float64]{
Lower: 1.0,
Upper: 2.0,
LowerType: pgtype.Inclusive,
UpperType: pgtype.Inclusive,
Valid: true,
}
var outputRangeType pgtype.Range[float64]
err = tx.QueryRow(ctx, "SELECT $1::examplefloatrange", inputRangeType).Scan(&outputRangeType)
require.NoError(t, err)
require.Equal(t, inputRangeType, outputRangeType)
})
}
func TestLoadMultiRangeType(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does support range types")
pgxtest.SkipPostgreSQLVersionLessThan(t, conn, 14) // multirange data type was added in 14 postgresql
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi, multirange_type_name=examplefloatmultirange)")
require.NoError(t, err)
// Register types
newRangeType, err := conn.LoadType(ctx, "examplefloatrange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
newMultiRangeType, err := conn.LoadType(ctx, "examplefloatmultirange")
require.NoError(t, err)
conn.TypeMap().RegisterType(newMultiRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")
var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{
{
Lower: 1.0,
Upper: 2.0,
LowerType: pgtype.Inclusive,
UpperType: pgtype.Inclusive,
Valid: true,
},
{
Lower: 3.0,
Upper: 4.0,
LowerType: pgtype.Exclusive,
UpperType: pgtype.Exclusive,
Valid: true,
},
}
var outputMultiRangeType pgtype.Multirange[pgtype.Range[float64]]
err = tx.QueryRow(ctx, "SELECT $1::examplefloatmultirange", inputMultiRangeType).Scan(&outputMultiRangeType)
require.NoError(t, err)
require.Equal(t, inputMultiRangeType, outputMultiRangeType)
})
}
 func TestStmtCacheInvalidationConn(t *testing.T) {
 	ctx := context.Background()

@@ -1048,7 +1294,10 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
 }

 func TestInsertDurationInterval(t *testing.T) {
-	pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()
+
+	pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
 		_, err := conn.Exec(context.Background(), "create temporary table t(duration INTERVAL(0) NOT NULL)")
 		require.NoError(t, err)
@@ -1082,7 +1331,7 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
 		rows.Close()
 		require.NoError(t, rows.Err())

-		if bytes.Compare(original, buf) != 0 {
+		if !bytes.Equal(original, buf) {
 			return
 		}
 	}
@@ -1090,3 +1339,82 @@ func TestRawValuesUnderlyingMemoryReused(t *testing.T) {
 		t.Fatal("expected buffer from RawValues to be overwritten by subsequent queries but it was not")
 	})
 }
// https://github.com/jackc/pgx/issues/1847
func TestConnDeallocateInvalidatedCachedStatementsWhenCanceled(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "CockroachDB returns decimal instead of integer for integer division")
var n int32
err := conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
// Divide by zero causes an error. baseRows.Close() calls Invalidate on the statement cache whenever an error was
// encountered by the query. Use this to purposely invalidate the query. If we had access to private fields of conn
// we could call conn.statementCache.InvalidateAll() instead.
err = conn.QueryRow(ctx, "select 1 / $1::int", 0).Scan(&n)
require.Error(t, err)
ctx2, cancel2 := context.WithCancel(ctx)
cancel2()
err = conn.QueryRow(ctx2, "select 1 / $1::int", 1).Scan(&n)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
err = conn.QueryRow(ctx, "select 1 / $1::int", 1).Scan(&n)
require.NoError(t, err)
require.EqualValues(t, 1, n)
})
}
// https://github.com/jackc/pgx/issues/1847
func TestConnDeallocateInvalidatedCachedStatementsInTransactionWithBatch(t *testing.T) {
t.Parallel()
ctx := context.Background()
connString := os.Getenv("PGX_TEST_DATABASE")
config := mustParseConfig(t, connString)
config.DefaultQueryExecMode = pgx.QueryExecModeCacheStatement
config.StatementCacheCapacity = 2
conn, err := pgx.ConnectConfig(ctx, config)
require.NoError(t, err)
tx, err := conn.Begin(ctx)
require.NoError(t, err)
defer tx.Rollback(ctx)
_, err = tx.Exec(ctx, "select $1::int + 1", 1)
require.NoError(t, err)
_, err = tx.Exec(ctx, "select $1::int + 2", 1)
require.NoError(t, err)
// This should invalidate the first cached statement.
_, err = tx.Exec(ctx, "select $1::int + 3", 1)
require.NoError(t, err)
batch := &pgx.Batch{}
batch.Queue("select $1::int + 1", 1)
err = tx.SendBatch(ctx, batch).Close()
require.NoError(t, err)
err = tx.Rollback(ctx)
require.NoError(t, err)
ensureConnValid(t, conn)
}
func TestErrNoRows(t *testing.T) {
t.Parallel()
// ensure we preserve old error message
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNoRows must match sql.ErrNoRows")
}


@@ -64,6 +64,33 @@ func (cts *copyFromSlice) Err() error {
 	return cts.err
 }
// CopyFromFunc returns a CopyFromSource interface that relies on nxtf for values.
// nxtf returns rows until it either signals an 'end of data' by returning row=nil and err=nil,
// or it returns an error. If nxtf returns an error, the copy is aborted.
func CopyFromFunc(nxtf func() (row []any, err error)) CopyFromSource {
return &copyFromFunc{next: nxtf}
}
type copyFromFunc struct {
next func() ([]any, error)
valueRow []any
err error
}
func (g *copyFromFunc) Next() bool {
g.valueRow, g.err = g.next()
// only return true if valueRow exists and no error
return g.valueRow != nil && g.err == nil
}
func (g *copyFromFunc) Values() ([]any, error) {
return g.valueRow, g.err
}
func (g *copyFromFunc) Err() error {
return g.err
}
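TestCopyFromFunc later in this changeset exercises this source end to end. As a compact reference, a sketch of the termination contract (the connection string and `numbers` table here are hypothetical):

package main

import (
	"context"
	"log"

	"github.com/jackc/pgx/v5"
)

func main() {
	ctx := context.Background()
	conn, err := pgx.Connect(ctx, "postgres://localhost/mydb") // hypothetical connection string
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(ctx)

	if _, err := conn.Exec(ctx, "create temporary table numbers(n int)"); err != nil {
		log.Fatal(err)
	}

	rows := [][]any{{1}, {2}, {3}}
	i := 0
	count, err := conn.CopyFrom(ctx, pgx.Identifier{"numbers"}, []string{"n"},
		pgx.CopyFromFunc(func() ([]any, error) {
			if i == len(rows) {
				return nil, nil // row == nil and err == nil signals a clean end of data
			}
			row := rows[i]
			i++
			return row, nil // returning a non-nil error here would abort the copy
		}))
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("copied %d rows", count)
}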
 // CopyFromSource is the interface used by *Conn.CopyFrom as the source for copy data.
 type CopyFromSource interface {
 	// Next returns true if there is another row and makes the next row data

@@ -85,6 +112,7 @@ type copyFrom struct {
 	columnNames   []string
 	rowSrc        CopyFromSource
 	readerErrChan chan error
+	mode          QueryExecMode
 }
 func (ct *copyFrom) run(ctx context.Context) (int64, error) {

@@ -105,9 +133,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
 	}
 	quotedColumnNames := cbuf.String()

-	sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
+	var sd *pgconn.StatementDescription
+	switch ct.mode {
+	case QueryExecModeExec, QueryExecModeSimpleProtocol:
+		// These modes don't support the binary format. Before the inclusion of the
+		// QueryExecModes, Conn.Prepare was called on every COPY operation to get
+		// the OIDs. These prepared statements were not cached.
+		//
+		// Since that's the same behavior provided by QueryExecModeDescribeExec,
+		// we'll default to that mode.
+		ct.mode = QueryExecModeDescribeExec
+		fallthrough
+	case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
+		var err error
+		sd, err = ct.conn.getStatementDescription(
+			ctx,
+			ct.mode,
+			fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
+		)
 		if err != nil {
-			return 0, err
+			return 0, fmt.Errorf("statement description failed: %w", err)
 		}
+	default:
+		return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
 	}

 	r, w := io.Pipe()
@@ -167,8 +215,13 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
 }

 func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (bool, []byte, error) {
+	const sendBufSize = 65536 - 5 // The packet has a 5-byte header
+	lastBufLen := 0
+	largestRowLen := 0

 	for ct.rowSrc.Next() {
+		lastBufLen = len(buf)

 		values, err := ct.rowSrc.Values()
 		if err != nil {
 			return false, nil, err

@@ -185,7 +238,15 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
 			}
 		}

-		if len(buf) > 65536 {
+		rowLen := len(buf) - lastBufLen
+		if rowLen > largestRowLen {
+			largestRowLen = rowLen
+		}
+
+		// Try not to overflow the size of the buffer PgConn.CopyFrom will be reading into. If that happens then the
+		// nature of io.Pipe means that the next Read will be short. This can lead to pathological send sizes such as
+		// 65531, 13, 65531, 13, 65531, 13.
+		if len(buf) > sendBufSize-largestRowLen {
 			return true, buf, nil
 		}
 	}

@@ -193,12 +254,14 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
 	return false, buf, nil
 }
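To make the flush condition concrete: a CopyData packet has a 5-byte header, so the payload target is 65531 bytes, and flushing once the buffer is within one largest-seen row of that target keeps each send close to, but not over, the target. A self-contained illustration of the same arithmetic (the 1200-byte row size is hypothetical):

package main

import "fmt"

// sendBufSize mirrors the constant in buildCopyBuf: 65536 minus the
// 5-byte CopyData packet header.
const sendBufSize = 65536 - 5

func main() {
	largestRowLen := 0
	bufLen := 0
	flushes := 0

	for i := 0; i < 200; i++ {
		rowLen := 1200 // hypothetical encoded row size
		if rowLen > largestRowLen {
			largestRowLen = rowLen
		}
		bufLen += rowLen

		// Flush while one more worst-case row still fits under the target.
		// Flushing at a fixed 65536 instead would overshoot the reader's
		// buffer and produce the short-read pattern described above.
		if bufLen > sendBufSize-largestRowLen {
			fmt.Printf("flush #%d at %d bytes\n", flushes+1, bufLen)
			flushes++
			bufLen = 0
		}
	}
	fmt.Printf("%d flushes, %d bytes left for the final send\n", flushes, bufLen)
}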
-// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion.
-// It returns the number of rows copied and an error.
+// CopyFrom uses the PostgreSQL copy protocol to perform bulk data insertion. It returns the number of rows copied and
+// an error.
 //
-// CopyFrom requires all values use the binary format. Almost all types
-// implemented by pgx use the binary format by default. Types implementing
-// Encoder can only be used if they encode to the binary format.
+// CopyFrom requires all values use the binary format. A pgtype.Type that supports the binary format must be registered
+// for the type of each column. Almost all types implemented by pgx support the binary format.
+//
+// Even though enum types appear to be strings, they still must be registered to use with CopyFrom. This can be done
+// with Conn.LoadType and pgtype.Map.RegisterType.
 func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
 	ct := &copyFrom{
 		conn: c,

@@ -206,6 +269,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [
 		columnNames:   columnNames,
 		rowSrc:        rowSrc,
 		readerErrChan: make(chan error),
+		mode:          c.config.DefaultQueryExecMode,
 	}

 	return ct.run(ctx)
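The enum note in the doc comment above deserves emphasis: enum values look like strings, but CopyFrom can only encode them once a pgtype.Type for the enum is registered. A hedged sketch of that registration step (the `color` type and `items` table are hypothetical; TestConnCopyFromEnum below covers the real behavior):

package pgxsketch

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// copyColors assumes a hypothetical enum created with
// "create type color as enum ('red', 'green')" and a table items(c color).
func copyColors(ctx context.Context, conn *pgx.Conn) error {
	// Without this step CopyFrom fails: it cannot binary-encode an
	// unregistered enum OID.
	typ, err := conn.LoadType(ctx, "color")
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterType(typ)

	_, err = conn.CopyFrom(ctx, pgx.Identifier{"items"}, []string{"c"},
		pgx.CopyFromRows([][]any{{"red"}, {"green"}}))
	return err
}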


@@ -14,9 +14,137 @@ import (
 	"github.com/stretchr/testify/require"
 )
func TestConnCopyWithAllQueryExecModes(t *testing.T) {
for _, mode := range pgxtest.AllQueryExecModes {
t.Run(mode.String(), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
cfg.DefaultQueryExecMode = mode
conn := mustConnect(t, cfg)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d text,
e timestamptz
)`)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
inputRows := [][]any{
{int16(0), int32(1), int64(2), "abc", tzedTime},
{nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]any
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
})
}
}
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
for _, mode := range pgxtest.KnownOIDQueryExecModes {
t.Run(mode.String(), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
cfg.DefaultQueryExecMode = mode
conn := mustConnect(t, cfg)
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int2,
b int4,
c int8,
d varchar,
e text,
f date,
g timestamptz
)`)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
inputRows := [][]any{
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
rows, err := conn.Query(ctx, "select * from foo")
if err != nil {
t.Errorf("Unexpected error for Query: %v", err)
}
var outputRows [][]any
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
if rows.Err() != nil {
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
}
if !reflect.DeepEqual(inputRows, outputRows) {
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
}
ensureConnValid(t, conn)
})
}
}
 func TestConnCopyFromSmall(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -37,7 +165,7 @@ func TestConnCopyFromSmall(t *testing.T) {
 		{nil, nil, nil, nil, nil, nil, nil},
 	}

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
 	if err != nil {
 		t.Errorf("Unexpected error for CopyFrom: %v", err)
 	}

@@ -45,7 +173,7 @@ func TestConnCopyFromSmall(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -73,6 +201,9 @@ func TestConnCopyFromSmall(t *testing.T) {

 func TestConnCopyFromSliceSmall(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -93,7 +224,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
 		{nil, nil, nil, nil, nil, nil, nil},
 	}

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"},
 		pgx.CopyFromSlice(len(inputRows), func(i int) ([]any, error) {
 			return inputRows[i], nil
 		}))

@@ -104,7 +235,7 @@ func TestConnCopyFromSliceSmall(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -132,11 +263,12 @@ func TestConnCopyFromSliceSmall(t *testing.T) {

 func TestConnCopyFromLarge(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)

+	pgxtest.SkipCockroachDB(t, conn, "Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/52722)")

 	mustExec(t, conn, `create temporary table foo(
 		a int2,
 		b int4,
@@ -156,7 +288,7 @@ func TestConnCopyFromLarge(t *testing.T) {
 		inputRows = append(inputRows, []any{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
 	}

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
 	if err != nil {
 		t.Errorf("Unexpected error for CopyFrom: %v", err)
 	}

@@ -164,7 +296,7 @@ func TestConnCopyFromLarge(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -192,10 +324,12 @@ func TestConnCopyFromLarge(t *testing.T) {

 func TestConnCopyFromEnum(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)

-	ctx := context.Background()
 	tx, err := conn.Begin(ctx)
 	require.NoError(t, err)
 	defer tx.Rollback(ctx)
@@ -220,7 +354,7 @@ func TestConnCopyFromEnum(t *testing.T) {
 		conn.TypeMap().RegisterType(typ)
 	}

-	_, err = tx.Exec(ctx, `create table foo(
+	_, err = tx.Exec(ctx, `create temporary table foo(
 		a text,
 		b color,
 		c fruit,

@@ -235,11 +369,11 @@ func TestConnCopyFromEnum(t *testing.T) {
 		{nil, nil, nil, nil, nil, nil},
 	}

-	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
+	copyCount, err := tx.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f"}, pgx.CopyFromRows(inputRows))
 	require.NoError(t, err)
 	require.EqualValues(t, len(inputRows), copyCount)

-	rows, err := conn.Query(ctx, "select * from foo")
+	rows, err := tx.Query(ctx, "select * from foo")
 	require.NoError(t, err)

 	var outputRows [][]any

@@ -255,12 +389,18 @@ func TestConnCopyFromEnum(t *testing.T) {
 		t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
 	}

+	err = tx.Rollback(ctx)
+	require.NoError(t, err)

 	ensureConnValid(t, conn)
 }
 func TestConnCopyFromJSON(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -280,7 +420,7 @@ func TestConnCopyFromJSON(t *testing.T) {
 		{nil, nil},
 	}

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
 	if err != nil {
 		t.Errorf("Unexpected error for CopyFrom: %v", err)
 	}

@@ -288,7 +428,7 @@ func TestConnCopyFromJSON(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -338,6 +478,9 @@ func (cfs *clientFailSource) Err() error {

 func TestConnCopyFromFailServerSideMidway(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -352,7 +495,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
 		{int32(3), "def"},
 	}

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
 	if err == nil {
 		t.Errorf("Expected CopyFrom return error, but it did not")
 	}

@@ -363,7 +506,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -414,6 +557,9 @@ func (fs *failSource) Err() error {

 func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -425,7 +571,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
 	startTime := time.Now()

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
 	if err == nil {
 		t.Errorf("Expected CopyFrom return error, but it did not")
 	}

@@ -442,7 +588,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
 		t.Errorf("Failing CopyFrom shouldn't have taken so long: %v", copyTime)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -491,6 +637,9 @@ func (fs *slowFailRaceSource) Err() error {

 func TestConnCopyFromSlowFailRace(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -499,7 +648,7 @@ func TestConnCopyFromSlowFailRace(t *testing.T) {
 		b bytea not null
 	)`)

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
 	if err == nil {
 		t.Errorf("Expected CopyFrom return error, but it did not")
 	}
@@ -516,6 +665,9 @@ func TestConnCopyFromSlowFailRace(t *testing.T) {

 func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -523,7 +675,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
 		a bytea not null
 	)`)

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
 	if err == nil {
 		t.Errorf("Expected CopyFrom return error, but it did not")
 	}

@@ -531,7 +683,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -576,6 +728,9 @@ func (cfs *clientFinalErrSource) Err() error {

 func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
 	t.Parallel()

+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
+	defer cancel()

 	conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
 	defer closeConn(t, conn)
@@ -583,7 +738,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
 		a bytea not null
 	)`)

-	copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
+	copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
 	if err == nil {
 		t.Errorf("Expected CopyFrom return error, but it did not")
 	}

@@ -591,7 +746,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
 		t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
 	}

-	rows, err := conn.Query(context.Background(), "select * from foo")
+	rows, err := conn.Query(ctx, "select * from foo")
 	if err != nil {
 		t.Errorf("Unexpected error for Query: %v", err)
 	}
@@ -615,3 +770,125 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
 	ensureConnValid(t, conn)
 }
func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int8
)`)
inputRows := [][]interface{}{
{"42"},
{"7"},
{8},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
rows, _ := conn.Query(ctx, "select * from foo")
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
require.NoError(t, err)
require.Equal(t, []int64{42, 7, 8}, nums)
ensureConnValid(t, conn)
}
// https://github.com/jackc/pgx/discussions/1891
func TestConnCopyFromAutomaticStringConversionArray(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a numeric[]
)`)
inputRows := [][]interface{}{
{[]string{"42"}},
{[]string{"7"}},
{[]string{"8", "9"}},
{[][]string{{"10", "11"}, {"12", "13"}}},
}
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
require.NoError(t, err)
require.EqualValues(t, len(inputRows), copyCount)
// Test reads as int64 and flattened array for simplicity.
rows, _ := conn.Query(ctx, "select * from foo")
nums, err := pgx.CollectRows(rows, pgx.RowTo[[]int64])
require.NoError(t, err)
require.Equal(t, [][]int64{{42}, {7}, {8, 9}, {10, 11, 12, 13}}, nums)
ensureConnValid(t, conn)
}
func TestCopyFromFunc(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int
)`)
dataCh := make(chan int, 1)
const channelItems = 10
go func() {
for i := 0; i < channelItems; i++ {
dataCh <- i
}
close(dataCh)
}()
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
pgx.CopyFromFunc(func() ([]any, error) {
v, ok := <-dataCh
if !ok {
return nil, nil
}
return []any{v}, nil
}))
require.ErrorIs(t, err, nil)
require.EqualValues(t, channelItems, copyCount)
rows, err := conn.Query(context.Background(), "select * from foo order by a")
require.NoError(t, err)
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
require.NoError(t, err)
require.Equal(t, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, nums)
// simulate a failure
copyCount, err = conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"},
pgx.CopyFromFunc(func() func() ([]any, error) {
x := 9
return func() ([]any, error) {
x++
if x > 100 {
return nil, fmt.Errorf("simulated error")
}
return []any{x}, nil
}
}()))
require.NotErrorIs(t, err, nil)
require.EqualValues(t, 0, copyCount) // no change, due to error
ensureConnValid(t, conn)
}

derived_types.go (new file)

@ -0,0 +1,256 @@
package pgx
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
"github.com/jackc/pgx/v5/pgtype"
)
/*
buildLoadDerivedTypesSQL generates the correct query for retrieving type information.
pgVersion: the major version of the PostgreSQL server
typeNames: the names of the types to load. If nil, load all types.
*/
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
supportsMultirange := (pgVersion >= 14)
var typeNamesClause string
if typeNames == nil {
// This should not occur; this will not return any types
typeNamesClause = "= ''"
} else {
typeNamesClause = "= ANY($1)"
}
parts := make([]string, 0, 10)
// Each of the type names provided might be found in pg_class or pg_type.
// Additionally, it may or may not include a schema portion.
parts = append(parts, `
WITH RECURSIVE
-- find the OIDs in pg_class which match one of the provided type names
selected_classes(oid,reltype) AS (
-- this query uses the namespace search path, so will match type names without a schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_catalog.pg_class
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
AND relname `, typeNamesClause, `
UNION ALL
-- this query will only match type names which include the schema prefix
SELECT pg_class.oid, pg_class.reltype
FROM pg_class
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
WHERE nspname || '.' || relname `, typeNamesClause, `
),
selected_types(oid) AS (
-- collect the OIDs from pg_types which correspond to the selected classes
SELECT reltype AS oid
FROM selected_classes
UNION ALL
-- as well as any other type names which match our criteria
SELECT pg_type.oid
FROM pg_type
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
WHERE typname `, typeNamesClause, `
OR nspname || '.' || typname `, typeNamesClause, `
),
-- this builds a parent/child mapping of objects, allowing us to know
-- all the child (ie: dependent) types that a parent (type) requires
-- As can be seen, there are 3 ways this can occur (the last of which
-- is due to being a composite class, where the composite fields are children)
pc(parent, child) AS (
SELECT parent.oid, parent.typelem
FROM pg_type parent
WHERE parent.typtype = 'b' AND parent.typelem != 0
UNION ALL
SELECT parent.oid, parent.typbasetype
FROM pg_type parent
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
UNION ALL
SELECT pg_type.oid, atttypid
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
),
-- Now construct a recursive query which includes a 'depth' element.
-- This is used to ensure that the "youngest" children are registered before
-- their parents.
relationships(parent, child, depth) AS (
SELECT DISTINCT 0::OID, selected_types.oid, 0
FROM selected_types
UNION ALL
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
FROM selected_classes c
inner join pg_type ON (c.reltype = pg_type.oid)
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
UNION ALL
SELECT pc.parent, pc.child, relationships.depth + 1
FROM pc
INNER JOIN relationships ON (pc.parent = relationships.child)
),
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
composite AS (
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
FROM pg_attribute
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
WHERE NOT attisdropped
AND attnum > 0
GROUP BY pg_type.oid
)
-- Bring together this information, showing all the information which might possibly be required
-- to complete the registration, applying filters to only show the items which relate to the selected
-- types/classes.
SELECT typname,
pg_namespace.nspname,
typtype,
typbasetype,
typelem,
pg_type.oid,`)
if supportsMultirange {
parts = append(parts, `
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
} else {
parts = append(parts, `
0 AS rngtypid,`)
}
parts = append(parts, `
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
attnames, atttypids
FROM relationships
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
if supportsMultirange {
parts = append(parts, `
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
}
parts = append(parts, `
LEFT OUTER JOIN composite USING (oid)
LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
WHERE NOT (typtype = 'b' AND typelem = 0)`)
parts = append(parts, `
GROUP BY typname, pg_namespace.nspname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
if supportsMultirange {
parts = append(parts, `
multirange.rngtypid,`)
}
parts = append(parts, `
attnames, atttypids
ORDER BY MAX(depth) desc, typname;`)
return strings.Join(parts, "")
}
type derivedTypeInfo struct {
Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32
TypeName, Typtype, NspName string
Attnames []string
Atttypids []uint32
}
// LoadTypes performs a single (complex) query, returning all the required
// information to register the named types, as well as any other types directly
// or indirectly required to complete the registration.
// The result of this call can be passed into RegisterTypes to complete the process.
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
m := c.TypeMap()
if len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}
	// Disregard server version errors. This will result in the SQL not
	// supporting recent structures such as multirange.
serverVersion, _ := serverVersion(c)
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
defer rows.Close()
result := make([]*pgtype.Type, 0, 100)
for rows.Next() {
ti := derivedTypeInfo{}
err = rows.Scan(&ti.TypeName, &ti.NspName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
if err != nil {
return nil, fmt.Errorf("While scanning type information: %w", err)
}
var type_ *pgtype.Type
switch ti.Typtype {
case "b": // array
dt, ok := m.TypeForOID(ti.Typelem)
if !ok {
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
case "c": // composite
var fields []pgtype.CompositeCodecField
for i, fieldName := range ti.Attnames {
dt, ok := m.TypeForOID(ti.Atttypids[i])
if !ok {
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
}
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
case "d": // domain
dt, ok := m.TypeForOID(ti.Typbasetype)
if !ok {
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
case "e": // enum
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
case "r": // range
dt, ok := m.TypeForOID(ti.Rngsubtype)
if !ok {
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
case "m": // multirange
dt, ok := m.TypeForOID(ti.Rngtypid)
if !ok {
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
}
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
default:
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
}
		// type_ cannot be nil here: every typtype case above either assigns it or returns early.
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
}
return result, nil
}
// serverVersion returns the postgresql server version.
func serverVersion(c *Conn) (int64, error) {
serverVersionStr := c.PgConn().ParameterStatus("server_version")
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
	// fail if no leading version number is present (e.g. the server is not PostgreSQL)
if serverVersionStr == "" {
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
}
version, err := strconv.ParseInt(serverVersionStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
}
return version, nil
}
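
As an orientation for the test that follows, a typical call site pairs LoadTypes with RegisterTypes on the connection's type map. A minimal sketch (the type name here is hypothetical):

	// Load the named type plus everything it depends on, ordered so that
	// dependencies are registered before their parents.
	types, err := conn.LoadTypes(ctx, []string{"mytype"})
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterTypes(types)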

derived_types_test.go (new file)

@ -0,0 +1,40 @@
package pgx_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/require"
)
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `
drop type if exists dtype_test;
drop domain if exists anotheruint64;
create domain anotheruint64 as numeric(20,0);
create type dtype_test as (
a text,
b int4,
c anotheruint64,
d anotheruint64[]
);`)
require.NoError(t, err)
defer conn.Exec(ctx, "drop type dtype_test")
defer conn.Exec(ctx, "drop domain anotheruint64")
types, err := conn.LoadTypes(ctx, []string{"dtype_test"})
require.NoError(t, err)
require.Len(t, types, 6)
require.Equal(t, types[0].Name, "public.anotheruint64")
require.Equal(t, types[1].Name, "anotheruint64")
require.Equal(t, types[2].Name, "public._anotheruint64")
require.Equal(t, types[3].Name, "_anotheruint64")
require.Equal(t, types[4].Name, "public.dtype_test")
require.Equal(t, types[5].Name, "dtype_test")
})
}

doc.go

@ -7,24 +7,25 @@ details.
Establishing a Connection

-The primary way of establishing a connection is with `pgx.Connect`.
+The primary way of establishing a connection is with [pgx.Connect]:

	conn, err := pgx.Connect(context.Background(), os.Getenv("DATABASE_URL"))

-The database connection string can be in URL or DSN format. Both PostgreSQL settings and pgx settings can be specified
-here. In addition, a config struct can be created by `ParseConfig` and modified before establishing the connection with
-`ConnectConfig` to configure settings such as tracing that cannot be configured with a connection string.
+The database connection string can be in URL or key/value format. Both PostgreSQL settings and pgx settings can be
+specified here. In addition, a config struct can be created by [ParseConfig] and modified before establishing the
+connection with [ConnectConfig] to configure settings such as tracing that cannot be configured with a connection
+string.

Connection Pool

-`*pgx.Conn` represents a single connection to the database and is not concurrency safe. Use package
+[*pgx.Conn] represents a single connection to the database and is not concurrency safe. Use package
github.com/jackc/pgx/v5/pgxpool for a concurrency safe connection pool.

Query Interface

pgx implements Query in the familiar database/sql style. However, pgx provides generic functions such as CollectRows and
-ForEachRow that are a simpler and safer way of processing rows than manually calling rows.Next(), rows.Scan, and
-rows.Err().
+ForEachRow that are a simpler and safer way of processing rows than manually calling defer rows.Close(), rows.Next(),
+rows.Scan, and rows.Err().

CollectRows can be used to collect all returned rows into a slice.
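
For example, a minimal sketch of that call (the query here is illustrative, not part of this diff):

	rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
	numbers, err := pgx.CollectRows(rows, pgx.RowTo[int32])
	if err != nil {
		return err
	}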
@ -40,7 +41,7 @@ directly.
	var sum, n int32
	rows, _ := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
-	_, err := pgx.ForEachRow(rows, []any{&n}, func(pgx.QueryFuncRow) error {
+	_, err := pgx.ForEachRow(rows, []any{&n}, func() error {
		sum += n
		return nil
	})
@ -69,8 +70,9 @@ Use Exec to execute a query that does not return a result set.
PostgreSQL Data Types

-The package pgtype provides extensive and customizable support for converting Go values to and from PostgreSQL values
-including array and composite types. See that package's documentation for details.
+pgx uses the pgtype package to convert Go values to and from PostgreSQL values. It supports many PostgreSQL types
+directly and is customizable and extendable. User defined data types such as enums, domains, and composite types may
+require type registration. See that package's documentation for details.

Transactions
@ -97,7 +99,8 @@ Transactions are started by calling Begin.
The Tx returned from Begin also implements the Begin method. This can be used to implement pseudo nested transactions.
These are internally implemented with savepoints.

-Use BeginTx to control the transaction mode.
+Use BeginTx to control the transaction mode. BeginTx also can be used to ensure a new transaction is created instead of
+a pseudo nested transaction.

BeginFunc and BeginTxFunc are functions that begin a transaction, execute a function, and commit or rollback the
transaction depending on the return value of the function. These can be simpler and less error prone to use.
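
As a sketch of the BeginFunc pattern (the table and statement are hypothetical, not part of this diff):

	err := pgx.BeginFunc(context.Background(), conn, func(tx pgx.Tx) error {
		_, err := tx.Exec(context.Background(), "insert into foo(id) values (1)")
		return err // a non-nil error rolls the transaction back; nil commits it
	})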
@ -160,17 +163,19 @@ notification is received or the context is canceled.
	_, err := conn.Exec(context.Background(), "listen channelname")
	if err != nil {
-		return nil
+		return err
	}

-	if notification, err := conn.WaitForNotification(context.Background()); err != nil {
-		// do something with notification
-	}
+	notification, err := conn.WaitForNotification(context.Background())
+	if err != nil {
+		return err
+	}
+	// do something with notification

Tracing and Logging

-pgx supports tracing by setting ConnConfig.Tracer.
+pgx supports tracing by setting ConnConfig.Tracer. To combine several tracers you can use the multitracer.Tracer.

In addition, the tracelog package provides the TraceLog type which lets a traditional logger act as a Tracer.
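
A minimal wiring sketch, assuming myLogger is some implementation of the tracelog.Logger interface (the adapter is hypothetical):

	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		return err
	}
	config.Tracer = &tracelog.TraceLog{
		Logger:   myLogger, // hypothetical tracelog.Logger implementation
		LogLevel: tracelog.LogLevelDebug,
	}
	conn, err := pgx.ConnectConfig(context.Background(), config)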
@ -178,12 +183,12 @@ For debug tracing of the actual PostgreSQL wire protocol messages see github.com
Lower Level PostgreSQL Functionality

-github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn in
+github.com/jackc/pgx/v5/pgconn contains a lower level PostgreSQL driver roughly at the level of libpq. pgx.Conn is
implemented on top of pgconn. The Conn.PgConn() method can be used to access this lower layer.

PgBouncer

-By default pgx automatically uses prepared statements. Prepared statements are incompaptible with PgBouncer. This can be
+By default pgx automatically uses prepared statements. Prepared statements are incompatible with PgBouncer. This can be
disabled by setting a different QueryExecMode in ConnConfig.DefaultQueryExecMode.
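
For example, a sketch of that PgBouncer-friendly configuration (the connection string source is assumed):

	config, err := pgx.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		return err
	}
	// The simple protocol avoids server-side prepared statements entirely.
	config.DefaultQueryExecMode = pgx.QueryExecModeSimpleProtocol
	conn, err := pgx.ConnectConfig(context.Background(), config)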
*/
package pgx


@ -2,7 +2,7 @@ package main
import (
	"context"
-	"io/ioutil"
+	"io"
	"log"
	"net/http"
	"os"
@ -29,7 +29,7 @@ func getUrlHandler(w http.ResponseWriter, req *http.Request) {
func putUrlHandler(w http.ResponseWriter, req *http.Request) {
	id := req.URL.Path
	var url string
-	if body, err := ioutil.ReadAll(req.Body); err == nil {
+	if body, err := io.ReadAll(req.Body); err == nil {
		url = string(body)
	} else {
		http.Error(w, "Internal server error", http.StatusInternalServerError)


@ -3,7 +3,6 @@ package pgx
import (
	"fmt"

-	"github.com/jackc/pgx/v5/internal/anynil"
	"github.com/jackc/pgx/v5/pgconn"
	"github.com/jackc/pgx/v5/pgtype"
)
@ -22,10 +21,15 @@ type ExtendedQueryBuilder struct {
func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescription, args []any) error {
	eqb.reset()

-	anynil.NormalizeSlice(args)

	if sd == nil {
-		return eqb.appendParamsForQueryExecModeExec(m, args)
+		for i := range args {
+			err := eqb.appendParam(m, 0, pgtype.TextFormatCode, args[i])
+			if err != nil {
+				err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
+				return err
+			}
+		}
+		return nil
	}

	if len(sd.ParamOIDs) != len(args) {
@ -35,7 +39,7 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
	for i := range args {
		err := eqb.appendParam(m, sd.ParamOIDs[i], -1, args[i])
		if err != nil {
-			err = fmt.Errorf("failed to encode args[%d]: %v", i, err)
+			err = fmt.Errorf("failed to encode args[%d]: %w", i, err)
			return err
		}
	}
@ -51,14 +55,33 @@ func (eqb *ExtendedQueryBuilder) Build(m *pgtype.Map, sd *pgconn.StatementDescri
// must be an untyped nil.
func (eqb *ExtendedQueryBuilder) appendParam(m *pgtype.Map, oid uint32, format int16, arg any) error {
	if format == -1 {
-		format = eqb.chooseParameterFormatCode(m, oid, arg)
+		preferredFormat := eqb.chooseParameterFormatCode(m, oid, arg)
+		preferredErr := eqb.appendParam(m, oid, preferredFormat, arg)
+		if preferredErr == nil {
+			return nil
+		}
+
+		var otherFormat int16
+		if preferredFormat == TextFormatCode {
+			otherFormat = BinaryFormatCode
+		} else {
+			otherFormat = TextFormatCode
+		}
+
+		otherErr := eqb.appendParam(m, oid, otherFormat, arg)
+		if otherErr == nil {
+			return nil
+		}
+
+		return preferredErr // return the error from the preferred format
	}

-	eqb.ParamFormats = append(eqb.ParamFormats, format)
	v, err := eqb.encodeExtendedParamValue(m, oid, format, arg)
	if err != nil {
		return err
	}

+	eqb.ParamFormats = append(eqb.ParamFormats, format)
	eqb.ParamValues = append(eqb.ParamValues, v)

	return nil
@ -93,10 +116,6 @@ func (eqb *ExtendedQueryBuilder) reset() {
}

func (eqb *ExtendedQueryBuilder) encodeExtendedParamValue(m *pgtype.Map, oid uint32, formatCode int16, arg any) ([]byte, error) {
-	if anynil.Is(arg) {
-		return nil, nil
-	}
-
	if eqb.paramValueBytes == nil {
		eqb.paramValueBytes = make([]byte, 0, 128)
	}
@ -125,61 +144,3 @@ func (eqb *ExtendedQueryBuilder) chooseParameterFormatCode(m *pgtype.Map, oid ui
	return m.FormatCodeForOID(oid)
}
// appendParamsForQueryExecModeExec appends the args to eqb.
//
// Parameters must be encoded in the text format because of differences in type conversion between timestamps and
// dates. In QueryExecModeExec we don't know what the actual PostgreSQL type is. To determine the type we use the
// Go type to OID type mapping registered by RegisterDefaultPgType. However, the Go time.Time represents both
// PostgreSQL timestamp[tz] and date. To use the binary format we would need to also specify what the PostgreSQL
// type OID is. But that would mean telling PostgreSQL that we have sent a timestamp[tz] when what is needed is a date.
// This means that the value is converted from text to timestamp[tz] to date. This means it does a time zone conversion
// before converting it to date. This means that dates can be shifted by one day. In text format without that double
// type conversion it takes the date directly and ignores time zone (i.e. it works).
//
// Given that the whole point of QueryExecModeExec is to operate without having to know the PostgreSQL types there is
// no way to safely use binary or to specify the parameter OIDs.
func (eqb *ExtendedQueryBuilder) appendParamsForQueryExecModeExec(m *pgtype.Map, args []any) error {
for _, arg := range args {
if arg == nil {
err := eqb.appendParam(m, 0, TextFormatCode, arg)
if err != nil {
return err
}
} else {
dt, ok := m.TypeForValue(arg)
if !ok {
var tv pgtype.TextValuer
if tv, ok = arg.(pgtype.TextValuer); ok {
t, err := tv.TextValue()
if err != nil {
return err
}
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = t
}
}
}
if !ok {
var str fmt.Stringer
if str, ok = arg.(fmt.Stringer); ok {
dt, ok = m.TypeForOID(pgtype.TextOID)
if ok {
arg = str.String()
}
}
}
if !ok {
return &unknownArgumentTypeQueryExecModeExecError{arg: arg}
}
err := eqb.appendParam(m, dt.OID, TextFormatCode, arg)
if err != nil {
return err
}
}
}
return nil
}

go.mod

@ -1,20 +1,21 @@
module github.com/jackc/pgx/v5

-go 1.18
+go 1.23.0

require (
	github.com/jackc/pgpassfile v1.0.0
-	github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b
-	github.com/jackc/puddle/v2 v2.0.0-beta.1
-	github.com/stretchr/testify v1.8.0
-	golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa
-	golang.org/x/text v0.3.7
+	github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761
+	github.com/jackc/puddle/v2 v2.2.2
+	github.com/stretchr/testify v1.8.1
+	golang.org/x/crypto v0.37.0
+	golang.org/x/sync v0.13.0
+	golang.org/x/text v0.24.0
)

require (
	github.com/davecgh/go-spew v1.1.1 // indirect
-	github.com/kr/pretty v0.1.0 // indirect
+	github.com/kr/pretty v0.3.0 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
-	gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
+	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
)

go.sum

@ -1,36 +1,45 @@
+github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
-github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b h1:C8S2+VttkHFdOOCXJe+YGfa4vHYwlt4Zx+IVXQ97jYg=
-github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
+github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/puddle/v2 v2.0.0-beta.1 h1:Y4Ao+kFWANtDhWUkdw1JcbH+x84/aq6WUfhVQ1wdib8=
-github.com/jackc/puddle/v2 v2.0.0-beta.1/go.mod h1:itE7ZJY8xnoo0JqJEpSMprN0f+NQkMCuEV/N9j8h0oc=
+github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
+github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
-github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
+github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
+github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
+github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
+github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
+github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k=
+github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
+github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
+github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
-golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b h1:QAqMVf3pSa6eeTsuklijukjXBlj7Es2QQplab+/RbQ4=
-golang.org/x/crypto v0.0.0-20211209193657-4570a0811e8b/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
-golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c=
-golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
+golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
-golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
-golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
+golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0=
+golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
+gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=


@ -79,7 +79,7 @@ func ensureConnValid(t testing.TB, conn *pgx.Conn) {
	}

	if rows.Err() != nil {
-		t.Fatalf("conn.Query failed: %v", err)
+		t.Fatalf("conn.Query failed: %v", rows.Err())
	}

	if rowCount != 10 {


@ -1,36 +0,0 @@
package anynil
import "reflect"
// Is returns true if value is any type of nil. e.g. nil or []byte(nil).
func Is(value any) bool {
if value == nil {
return true
}
refVal := reflect.ValueOf(value)
switch refVal.Kind() {
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.UnsafePointer, reflect.Interface, reflect.Slice:
return refVal.IsNil()
default:
return false
}
}
// Normalize converts typed nils (e.g. []byte(nil)) into untyped nil. Other values are returned unmodified.
func Normalize(v any) any {
if Is(v) {
return nil
}
return v
}
// NormalizeSlice converts all typed nils (e.g. []byte(nil)) in s into untyped nils. Other values are unmodified. s is
// mutated in place.
func NormalizeSlice(s []any) {
for i := range s {
if Is(s[i]) {
s[i] = nil
}
}
}


@ -1,4 +1,7 @@
// Package iobufpool implements a global segregated-fit pool of buffers for IO.
+//
+// It uses *[]byte instead of []byte to avoid the sync.Pool allocation with Put. Unfortunately, using a pointer to avoid
+// an allocation is purposely not documented. https://github.com/golang/go/issues/16323
package iobufpool

import "sync"
@ -10,17 +13,27 @@ var pools [18]*sync.Pool
func init() {
	for i := range pools {
		bufLen := 1 << (minPoolExpOf2 + i)
-		pools[i] = &sync.Pool{New: func() any { return make([]byte, bufLen) }}
+		pools[i] = &sync.Pool{
+			New: func() any {
+				buf := make([]byte, bufLen)
+				return &buf
+			},
+		}
	}
}

// Get gets a []byte of len size with cap <= size*2.
-func Get(size int) []byte {
+func Get(size int) *[]byte {
	i := getPoolIdx(size)
	if i >= len(pools) {
-		return make([]byte, size)
+		buf := make([]byte, size)
+		return &buf
	}
-	return pools[i].Get().([]byte)[:size]
+
+	ptrBuf := (pools[i].Get().(*[]byte))
+	*ptrBuf = (*ptrBuf)[:size]
+	return ptrBuf
}

func getPoolIdx(size int) int {
@ -36,8 +49,8 @@ func getPoolIdx(size int) int {
}

// Put returns buf to the pool.
-func Put(buf []byte) {
-	i := putPoolIdx(cap(buf))
+func Put(buf *[]byte) {
+	i := putPoolIdx(cap(*buf))
	if i < 0 {
		return
	}
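
With the pointer-based API, callers pass the same *[]byte through the whole lifecycle. A sketch (src is an assumed io.Reader; the package is internal to pgx, so this is illustrative only):

	buf := iobufpool.Get(8 * 1024) // *[]byte with len 8192
	n, err := src.Read(*buf)       // operate on the dereferenced slice
	// ... use (*buf)[:n] ...
	iobufpool.Put(buf)             // returning the pointer avoids an allocation in sync.Pool.Put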


@ -30,15 +30,15 @@ func TestGetCap(t *testing.T) {
	}

	for _, tt := range tests {
		buf := iobufpool.Get(tt.requestedLen)
-		assert.Equalf(t, tt.requestedLen, len(buf), "bad len for requestedLen: %d", len(buf), tt.requestedLen)
-		assert.Equalf(t, tt.expectedCap, cap(buf), "bad cap for requestedLen: %d", tt.requestedLen)
+		assert.Equalf(t, tt.requestedLen, len(*buf), "bad len for requestedLen: %d", len(*buf), tt.requestedLen)
+		assert.Equalf(t, tt.expectedCap, cap(*buf), "bad cap for requestedLen: %d", tt.requestedLen)
	}
}

func TestPutHandlesWrongSizedBuffers(t *testing.T) {
	for putBufSize := range []int{0, 1, 128, 250, 256, 257, 1023, 1024, 1025, 1 << 28} {
		putBuf := make([]byte, putBufSize)
-		iobufpool.Put(putBuf)
+		iobufpool.Put(&putBuf)

		tests := []struct {
			requestedLen int
@ -62,8 +62,8 @@ func TestPutHandlesWrongSizedBuffers(t *testing.T) {
		}

		for _, tt := range tests {
			getBuf := iobufpool.Get(tt.requestedLen)
-			assert.Equalf(t, tt.requestedLen, len(getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
-			assert.Equalf(t, tt.expectedCap, cap(getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
+			assert.Equalf(t, tt.requestedLen, len(*getBuf), "len(putBuf): %d, requestedLen: %d", len(putBuf), tt.requestedLen)
+			assert.Equalf(t, tt.expectedCap, cap(*getBuf), "cap(putBuf): %d, requestedLen: %d", cap(putBuf), tt.requestedLen)
		}
	}
}
@ -73,10 +73,10 @@ func TestPutGetBufferReuse(t *testing.T) {
	// it not to be. So try many times.
	for i := 0; i < 100000; i++ {
		buf := iobufpool.Get(4)
-		buf[0] = 1
+		(*buf)[0] = 1
		iobufpool.Put(buf)
		buf = iobufpool.Get(4)
-		if buf[0] == 1 {
+		if (*buf)[0] == 1 {
			return
		}
	}


@ -1,70 +0,0 @@
package nbconn
import (
"sync"
)
const minBufferQueueLen = 8
type bufferQueue struct {
lock sync.Mutex
queue [][]byte
r, w int
}
func (bq *bufferQueue) pushBack(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
bq.queue[bq.w] = buf
bq.w++
}
func (bq *bufferQueue) pushFront(buf []byte) {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.w >= len(bq.queue) {
bq.growQueue()
}
copy(bq.queue[bq.r+1:bq.w+1], bq.queue[bq.r:bq.w])
bq.queue[bq.r] = buf
bq.w++
}
func (bq *bufferQueue) popFront() []byte {
bq.lock.Lock()
defer bq.lock.Unlock()
if bq.r == bq.w {
return nil
}
buf := bq.queue[bq.r]
bq.queue[bq.r] = nil // Clear reference so it can be garbage collected.
bq.r++
if bq.r == bq.w {
bq.r = 0
bq.w = 0
if len(bq.queue) > minBufferQueueLen {
bq.queue = make([][]byte, minBufferQueueLen)
}
}
return buf
}
func (bq *bufferQueue) growQueue() {
desiredLen := (len(bq.queue) + 1) * 3 / 2
if desiredLen < minBufferQueueLen {
desiredLen = minBufferQueueLen
}
newQueue := make([][]byte, desiredLen)
copy(newQueue, bq.queue)
bq.queue = newQueue
}


@ -1,536 +0,0 @@
// Package nbconn implements a non-blocking net.Conn wrapper.
//
// It is designed to solve three problems.
//
// The first is resolving the deadlock that can occur when both sides of a connection are blocked writing because all
// buffers between are full. See https://github.com/jackc/pgconn/issues/27 for discussion.
//
// The second is the inability to use a write deadline with a TLS.Conn without killing the connection.
//
// The third is to efficiently check if a connection has been closed via a non-blocking read.
package nbconn
import (
"crypto/tls"
"errors"
"io"
"net"
"os"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
var errClosed = errors.New("closed")
var ErrWouldBlock = new(wouldBlockError)
const fakeNonblockingWaitDuration = 100 * time.Millisecond
// NonBlockingDeadline is a magic value that when passed to Set[Read]Deadline places the connection in non-blocking read
// mode.
var NonBlockingDeadline = time.Date(1900, 1, 1, 0, 0, 0, 608536336, time.UTC)
// disableSetDeadlineDeadline is a magic value that when passed to Set[Read|Write]Deadline causes those methods to
// ignore all future calls.
var disableSetDeadlineDeadline = time.Date(1900, 1, 1, 0, 0, 0, 968549727, time.UTC)
// wouldBlockError implements net.Error so tls.Conn will recognize ErrWouldBlock as a temporary error.
type wouldBlockError struct{}
func (*wouldBlockError) Error() string {
return "would block"
}
func (*wouldBlockError) Timeout() bool { return true }
func (*wouldBlockError) Temporary() bool { return true }
// Conn is a net.Conn where Write never blocks and always succeeds. Flush or Read must be called to actually write to
// the underlying connection.
type Conn interface {
net.Conn
// Flush flushes any buffered writes.
Flush() error
// BufferReadUntilBlock reads and buffers any sucessfully read bytes until the read would block.
BufferReadUntilBlock() error
}
// NetConn is a non-blocking net.Conn wrapper. It implements net.Conn.
type NetConn struct {
conn net.Conn
rawConn syscall.RawConn
readQueue bufferQueue
writeQueue bufferQueue
readFlushLock sync.Mutex
// non-blocking writes with syscall.RawConn are done with a callback function. By using these fields instead of the
// callback functions closure to pass the buf argument and receive the n and err results we avoid some allocations.
nonblockWriteBuf []byte
nonblockWriteErr error
nonblockWriteN int
readDeadlineLock sync.Mutex
readDeadline time.Time
readNonblocking bool
writeDeadlineLock sync.Mutex
writeDeadline time.Time
// Only access with atomics
closed int64 // 0 = not closed, 1 = closed
}
func NewNetConn(conn net.Conn, fakeNonBlockingIO bool) *NetConn {
nc := &NetConn{
conn: conn,
}
if !fakeNonBlockingIO {
if sc, ok := conn.(syscall.Conn); ok {
if rawConn, err := sc.SyscallConn(); err == nil {
nc.rawConn = rawConn
}
}
}
return nc
}
// Read implements io.Reader.
func (c *NetConn) Read(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return 0, err
}
for n < len(b) {
buf := c.readQueue.popFront()
if buf == nil {
break
}
copiedN := copy(b[n:], buf)
if copiedN < len(buf) {
buf = buf[copiedN:]
c.readQueue.pushFront(buf)
} else {
iobufpool.Put(buf)
}
n += copiedN
}
// If any bytes were already buffered return them without trying to do a Read. Otherwise, when the caller is trying to
// Read up to len(b) bytes but all available bytes have already been buffered the underlying Read would block.
if n > 0 {
return n, nil
}
var readNonblocking bool
c.readDeadlineLock.Lock()
readNonblocking = c.readNonblocking
c.readDeadlineLock.Unlock()
var readN int
if readNonblocking {
readN, err = c.nonblockingRead(b[n:])
} else {
readN, err = c.conn.Read(b[n:])
}
n += readN
return n, err
}
// Write implements io.Writer. It never blocks due to buffering all writes. It will only return an error if the Conn is
// closed. Call Flush to actually write to the underlying connection.
func (c *NetConn) Write(b []byte) (n int, err error) {
if c.isClosed() {
return 0, errClosed
}
buf := iobufpool.Get(len(b))
copy(buf, b)
c.writeQueue.pushBack(buf)
return len(b), nil
}
func (c *NetConn) Close() (err error) {
swapped := atomic.CompareAndSwapInt64(&c.closed, 0, 1)
if !swapped {
return errClosed
}
defer func() {
closeErr := c.conn.Close()
if err == nil {
err = closeErr
}
}()
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
err = c.flush()
if err != nil {
return err
}
return nil
}
func (c *NetConn) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *NetConn) RemoteAddr() net.Addr {
return c.conn.RemoteAddr()
}
// SetDeadline is the equivalent of calling SetReadDealine(t) and SetWriteDeadline(t).
func (c *NetConn) SetDeadline(t time.Time) error {
err := c.SetReadDeadline(t)
if err != nil {
return err
}
return c.SetWriteDeadline(t)
}
// SetReadDeadline sets the read deadline as t. If t == NonBlockingDeadline then future reads will be non-blocking.
func (c *NetConn) SetReadDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
if c.readDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.readDeadline = t
return nil
}
if t == NonBlockingDeadline {
c.readNonblocking = true
t = time.Time{}
} else {
c.readNonblocking = false
}
c.readDeadline = t
return c.conn.SetReadDeadline(t)
}
func (c *NetConn) SetWriteDeadline(t time.Time) error {
if c.isClosed() {
return errClosed
}
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
if c.writeDeadline == disableSetDeadlineDeadline {
return nil
}
if t == disableSetDeadlineDeadline {
c.writeDeadline = t
return nil
}
c.writeDeadline = t
return c.conn.SetWriteDeadline(t)
}
func (c *NetConn) Flush() error {
if c.isClosed() {
return errClosed
}
c.readFlushLock.Lock()
defer c.readFlushLock.Unlock()
return c.flush()
}
// flush does the actual work of flushing the writeQueue. readFlushLock must already be held.
func (c *NetConn) flush() error {
var stopChan chan struct{}
var errChan chan error
defer func() {
if stopChan != nil {
select {
case stopChan <- struct{}{}:
case <-errChan:
}
}
}()
for buf := c.writeQueue.popFront(); buf != nil; buf = c.writeQueue.popFront() {
remainingBuf := buf
for len(remainingBuf) > 0 {
n, err := c.nonblockingWrite(remainingBuf)
remainingBuf = remainingBuf[n:]
if err != nil {
if !errors.Is(err, ErrWouldBlock) {
buf = buf[:len(remainingBuf)]
copy(buf, remainingBuf)
c.writeQueue.pushFront(buf)
return err
}
// Writing was blocked. Reading might unblock it.
if stopChan == nil {
stopChan, errChan = c.bufferNonblockingRead()
}
select {
case err := <-errChan:
stopChan = nil
return err
default:
}
}
}
iobufpool.Put(buf)
}
return nil
}
func (c *NetConn) BufferReadUntilBlock() error {
for {
buf := iobufpool.Get(8 * 1024)
n, err := c.nonblockingRead(buf)
if n > 0 {
buf = buf[:n]
c.readQueue.pushBack(buf)
}
if err != nil {
if errors.Is(err, ErrWouldBlock) {
return nil
} else {
return err
}
}
}
}
func (c *NetConn) bufferNonblockingRead() (stopChan chan struct{}, errChan chan error) {
stopChan = make(chan struct{})
errChan = make(chan error, 1)
go func() {
for {
err := c.BufferReadUntilBlock()
if err != nil {
errChan <- err
return
}
select {
case <-stopChan:
return
default:
}
}
}()
return stopChan, errChan
}
func (c *NetConn) isClosed() bool {
closed := atomic.LoadInt64(&c.closed)
return closed == 1
}
func (c *NetConn) nonblockingWrite(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingWrite(b)
} else {
return c.realNonblockingWrite(b)
}
}
func (c *NetConn) fakeNonblockingWrite(b []byte) (n int, err error) {
c.writeDeadlineLock.Lock()
defer c.writeDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.writeDeadline.IsZero() || deadline.Before(c.writeDeadline) {
err = c.conn.SetWriteDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetWriteDeadline(c.writeDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Write(b)
}
// realNonblockingWrite does a non-blocking write. readFlushLock must already be held.
func (c *NetConn) realNonblockingWrite(b []byte) (n int, err error) {
c.nonblockWriteBuf = b
c.nonblockWriteN = 0
c.nonblockWriteErr = nil
err = c.rawConn.Write(func(fd uintptr) (done bool) {
c.nonblockWriteN, c.nonblockWriteErr = syscall.Write(int(fd), c.nonblockWriteBuf)
return true
})
n = c.nonblockWriteN
if err == nil && c.nonblockWriteErr != nil {
if errors.Is(c.nonblockWriteErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = c.nonblockWriteErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
return n, nil
}
func (c *NetConn) nonblockingRead(b []byte) (n int, err error) {
if c.rawConn == nil {
return c.fakeNonblockingRead(b)
} else {
return c.realNonblockingRead(b)
}
}
func (c *NetConn) fakeNonblockingRead(b []byte) (n int, err error) {
c.readDeadlineLock.Lock()
defer c.readDeadlineLock.Unlock()
deadline := time.Now().Add(fakeNonblockingWaitDuration)
if c.readDeadline.IsZero() || deadline.Before(c.readDeadline) {
err = c.conn.SetReadDeadline(deadline)
if err != nil {
return 0, err
}
defer func() {
// Ignoring error resetting deadline as there is nothing that can reasonably be done if it fails.
c.conn.SetReadDeadline(c.readDeadline)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
err = ErrWouldBlock
}
}
}()
}
return c.conn.Read(b)
}
func (c *NetConn) realNonblockingRead(b []byte) (n int, err error) {
var funcErr error
err = c.rawConn.Read(func(fd uintptr) (done bool) {
n, funcErr = syscall.Read(int(fd), b)
return true
})
if err == nil && funcErr != nil {
if errors.Is(funcErr, syscall.EWOULDBLOCK) {
err = ErrWouldBlock
} else {
err = funcErr
}
}
if err != nil {
// n may be -1 when an error occurs.
if n < 0 {
n = 0
}
return n, err
}
// syscall read did not return an error and 0 bytes were read means EOF.
if n == 0 {
return 0, io.EOF
}
return n, nil
}
// syscall.Conn is interface
// TLSClient establishes a TLS connection as a client over conn using config.
//
// To avoid the first Read on the returned *TLSConn also triggering a Write due to the TLS handshake and thereby
// potentially causing a read and write deadlines to behave unexpectedly, Handshake is called explicitly before the
// *TLSConn is returned.
func TLSClient(conn *NetConn, config *tls.Config) (*TLSConn, error) {
tc := tls.Client(conn, config)
err := tc.Handshake()
if err != nil {
return nil, err
}
// Ensure last written part of Handshake is actually sent.
err = conn.Flush()
if err != nil {
return nil, err
}
return &TLSConn{
tlsConn: tc,
nbConn: conn,
}, nil
}
// TLSConn is a TLS wrapper around a *Conn. It works around a temporary write error (such as a timeout) being fatal to a
// tls.Conn.
type TLSConn struct {
tlsConn *tls.Conn
nbConn *NetConn
}
func (tc *TLSConn) Read(b []byte) (n int, err error) { return tc.tlsConn.Read(b) }
func (tc *TLSConn) Write(b []byte) (n int, err error) { return tc.tlsConn.Write(b) }
func (tc *TLSConn) BufferReadUntilBlock() error { return tc.nbConn.BufferReadUntilBlock() }
func (tc *TLSConn) Flush() error { return tc.nbConn.Flush() }
func (tc *TLSConn) LocalAddr() net.Addr { return tc.tlsConn.LocalAddr() }
func (tc *TLSConn) RemoteAddr() net.Addr { return tc.tlsConn.RemoteAddr() }
func (tc *TLSConn) Close() error {
// tls.Conn.closeNotify() sets a 5 second deadline to avoid blocking, sends a TLS alert close notification, and then
// sets the deadline to now. This causes NetConn's Close not to be able to flush the write buffer. Instead we set our
// own 5 second deadline then make all set deadlines no-op.
tc.tlsConn.SetDeadline(time.Now().Add(time.Second * 5))
tc.tlsConn.SetDeadline(disableSetDeadlineDeadline)
return tc.tlsConn.Close()
}
func (tc *TLSConn) SetDeadline(t time.Time) error { return tc.tlsConn.SetDeadline(t) }
func (tc *TLSConn) SetReadDeadline(t time.Time) error { return tc.tlsConn.SetReadDeadline(t) }
func (tc *TLSConn) SetWriteDeadline(t time.Time) error { return tc.tlsConn.SetWriteDeadline(t) }


@ -1,584 +0,0 @@
package nbconn_test
import (
"crypto/tls"
"errors"
"io"
"net"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5/internal/nbconn"
"github.com/stretchr/testify/require"
)
// Test keys generated with:
//
// $ openssl req -x509 -newkey rsa:2048 -keyout key.pem -out cert.pem -sha256 -nodes -days 20000 -subj '/CN=localhost'
var testTLSPublicKey = []byte(`-----BEGIN CERTIFICATE-----
MIICpjCCAY4CCQCjQKYdUDQzKDANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls
b2NhbGhvc3QwIBcNMjIwNjA0MTY1MzE2WhgPMjA3NzAzMDcxNjUzMTZaMBQxEjAQ
BgNVBAMMCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
ALHbOu80cfSPufKTZsKf3E5rCXHeIHjaIbgHEXA2SW/n77U8oZX518s+27FO0sK5
yA0WnEIwY34PU359sNR5KelARGnaeh3HdaGm1nuyyxBtwwAqIuM0UxGAMF/mQ4lT
caZPxG+7WlYDqnE3eVXUtG4c+T7t5qKAB3MtfbzKFSjczkWkroi6cTypmHArGghT
0VWWVu0s9oNp5q8iWchY2o9f0aIjmKv6FgtilO+geev+4U+QvtvrziR5BO3/3EgW
c5TUVcf+lwkvp8ziXvargmjjnNTyeF37y4KpFcex0v7z7hSrUK4zU0+xRn7Bp17v
7gzj0xN+HCsUW1cjPFNezX0CAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAbEBzewzg
Z5F+BqMSxP3HkMCkLLH0N9q0/DkZaVyZ38vrjcjaDYuabq28kA2d5dc5jxsQpvTw
HTGqSv1ZxJP3pBFv6jLSh8xaM6tUkk482Q6DnZGh97CD4yup/yJzkn5nv9OHtZ9g
TnaQeeXgOz0o5Zq9IpzHJb19ysya3UCIK8oKXbSO4Qd168seCq75V2BFHDpmejjk
D92eT6WODlzzvZbhzA1F3/cUilZdhbQtJMqdecKvD+yrBpzGVqzhWQsXwsRAU1fB
hShx+D14zUGM2l4wlVzOAuGh4ZL7x3AjJsc86TsCavTspS0Xl51j+mRbiULq7G7Y
E7ZYmaKTMOhvkg==
-----END CERTIFICATE-----`)
// The strings.ReplaceAll is used to placate any secret scanners that would squawk if they saw a private key embedded in
// source code.
var testTLSPrivateKey = []byte(strings.ReplaceAll(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCx2zrvNHH0j7ny
k2bCn9xOawlx3iB42iG4BxFwNklv5++1PKGV+dfLPtuxTtLCucgNFpxCMGN+D1N+
fbDUeSnpQERp2nodx3WhptZ7sssQbcMAKiLjNFMRgDBf5kOJU3GmT8Rvu1pWA6px
N3lV1LRuHPk+7eaigAdzLX28yhUo3M5FpK6IunE8qZhwKxoIU9FVllbtLPaDaeav
IlnIWNqPX9GiI5ir+hYLYpTvoHnr/uFPkL7b684keQTt/9xIFnOU1FXH/pcJL6fM
4l72q4Jo45zU8nhd+8uCqRXHsdL+8+4Uq1CuM1NPsUZ+wade7+4M49MTfhwrFFtX
IzxTXs19AgMBAAECggEBAJcHt5ARVQN8WUbobMawwX/F3QtYuPJnKWMAfYpwTwQ8
TI32orCcrObmxeBXMxowcPTMUnzSYmpV0W0EhvimuzRbYr0Qzcoj6nwPFOuN9GpL
CuBE58NQV4nw9SM6gfdHaKb17bWDvz5zdnUVym9cZKts5yrNEqDDX5Aq/S8n27gJ
/qheXwSxwETVO6kMEW1ndNIWDP8DPQ0E4O//RuMZwxpnZdnjGKkdVNy8I1BpgDgn
lwgkE3H3IciASki1GYXoyvrIiRwMQVzvYD2zcgwK9OZSjZe0TGwAGa+eQdbs3A9I
Ir1kYn6ZMGMRFJA2XHJW3hMZdWB/t2xMBGy75Uv9sAECgYEA1o+oRUYwwQ1MwBo9
YA6c00KjhFgrjdzyKPQrN14Q0dw5ErqRkhp2cs7BRdCDTDrjAegPc3Otg7uMa1vp
RgU/C72jwzFLYATvn+RLGRYRyqIE+bQ22/lLnXTrp4DCfdMrqWuQbIYouGHqfQrq
MfdtSUpQ6VZCi9zHehXOYwBMvQECgYEA1DTQFpe+tndIFmguxxaBwDltoPh5omzd
3vA7iFct2+UYk5W9shfAekAaZk2WufKmmC3OfBWYyIaJ7QwQpuGDS3zwjy6WFMTE
Otp2CypFCVahwHcvn2jYHmDMT0k0Pt6X2S3GAyWTyEPv7mAfKR1OWUYi7ZgdXpt0
TtL3Z3JyhH0CgYEAwveHUGuXodUUCPvPCZo9pzrGm1wDN8WtxskY/Bbd8dTLh9lA
riKdv3Vg6q+un3ZjETht0dsrsKib0HKUZqwdve11AcmpVHcnx4MLOqBzSk4vdzfr
IbhGna3A9VRrZyqcYjb75aGDHwjaqwVgCkdrZ03AeEeJ8M2N9cIa6Js9IAECgYBu
nlU24cVdspJWc9qml3ntrUITnlMxs1R5KXuvF9rk/OixzmYDV1RTpeTdHWcL6Yyk
WYSAtHVfWpq9ggOQKpBZonh3+w3rJ6MvFsBgE5nHQ2ywOrENhQbb1xPJ5NwiRcCc
Srsk2srNo3SIK30y3n8AFIqSljABKEIZ8Olc+JDvtQKBgQCiKz43zI6a0HscgZ77
DCBduWP4nk8BM7QTFxs9VypjrylMDGGtTKHc5BLA5fNZw97Hb7pcicN7/IbUnQUD
pz01y53wMSTJs0ocAxkYvUc5laF+vMsLpG2vp8f35w8uKuO7+vm5LAjUsPd099jG
2qWm8jTPeDC3sq+67s2oojHf+Q==
-----END TESTING KEY-----`, "TESTING KEY", "PRIVATE KEY"))
func testVariants(t *testing.T, f func(t *testing.T, local nbconn.Conn, remote net.Conn)) {
for _, tt := range []struct {
name string
makeConns func(t *testing.T) (local, remote net.Conn)
useTLS bool
fakeNonBlockingIO bool
}{
{
name: "Pipe",
makeConns: makePipeConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: true,
},
{
name: "TLS over TCP with Fake Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: true,
},
{
name: "TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: false,
fakeNonBlockingIO: false,
},
{
name: "TLS over TCP with Real Non-blocking IO",
makeConns: makeTCPConns,
useTLS: true,
fakeNonBlockingIO: false,
},
} {
t.Run(tt.name, func(t *testing.T) {
local, remote := tt.makeConns(t)
// Just to be sure both ends get closed. Also, it retains a reference so one side of the connection doesn't get
// garbage collected. This could happen when a test is testing against a non-responsive remote. Since it never
// uses remote it may be garbage collected leading to the connection being closed.
defer local.Close()
defer remote.Close()
var conn nbconn.Conn
netConn := nbconn.NewNetConn(local, tt.fakeNonBlockingIO)
if tt.useTLS {
cert, err := tls.X509KeyPair(testTLSPublicKey, testTLSPrivateKey)
require.NoError(t, err)
tlsServer := tls.Server(remote, &tls.Config{
Certificates: []tls.Certificate{cert},
})
serverTLSHandshakeChan := make(chan error)
go func() {
err := tlsServer.Handshake()
serverTLSHandshakeChan <- err
}()
tlsConn, err := nbconn.TLSClient(netConn, &tls.Config{InsecureSkipVerify: true})
require.NoError(t, err)
conn = tlsConn
err = <-serverTLSHandshakeChan
require.NoError(t, err)
remote = tlsServer
} else {
conn = netConn
}
f(t, conn, remote)
})
}
}
// makePipeConns returns a connected pair of net.Conns created with net.Pipe(). It is entirely synchronous so it is
// useful for testing an exact sequence of reads and writes with the underlying connection blocking.
func makePipeConns(t *testing.T) (local, remote net.Conn) {
local, remote = net.Pipe()
t.Cleanup(func() {
local.Close()
remote.Close()
})
return local, remote
}
// makeTCPConns returns a connected pair of net.Conns running over TCP on localhost.
func makeTCPConns(t *testing.T) (local, remote net.Conn) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
type acceptResultT struct {
conn net.Conn
err error
}
acceptChan := make(chan acceptResultT)
go func() {
conn, err := ln.Accept()
acceptChan <- acceptResultT{conn: conn, err: err}
}()
local, err = net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
acceptResult := <-acceptChan
require.NoError(t, acceptResult.err)
remote = acceptResult.conn
return local, remote
}
func TestWriteIsBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
// net.Pipe is synchronous so the Write would block if not buffered.
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
err := conn.Flush()
errChan <- err
}()
readBuf := make([]byte, len(writeBuf))
_, err = remote.Read(readBuf)
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
func TestSetWriteDeadlineDoesNotBlockWrite(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetWriteDeadline(time.Now())
require.NoError(t, err)
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
func TestReadFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 2)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
_, err = remote.Write([]byte("okay"))
errChan <- err
}()
readBuf := make([]byte, 4)
_, err = conn.Read(readBuf)
require.NoError(t, err)
require.Equal(t, []byte("okay"), readBuf)
require.NoError(t, <-errChan)
require.NoError(t, <-errChan)
})
}
func TestCloseFlushesWriteBuffer(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := []byte("test")
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
errChan := make(chan error, 1)
go func() {
readBuf := make([]byte, len(writeBuf))
_, err := remote.Read(readBuf)
errChan <- err
}()
err = conn.Close()
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
// This test exercises the non-blocking write path. Because writes are buffered it is difficult to trigger this with
// certainty and visibility. So this test tries to trigger what would otherwise be a deadlock by having both sides
// write large values.
func TestInternalNonBlockingWrite(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
errChan := make(chan error, 1)
go func() {
remoteWriteBuf := make([]byte, deadlockSize)
_, err := remote.Write(remoteWriteBuf)
if err != nil {
errChan <- err
return
}
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(remote, readBuf)
errChan <- err
}()
readBuf := make([]byte, deadlockSize)
_, err = io.ReadFull(conn, readBuf)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
require.NoError(t, <-errChan)
})
}
func TestInternalNonBlockingWriteWithDeadline(t *testing.T) {
const deadlockSize = 4 * 1024 * 1024
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
writeBuf := make([]byte, deadlockSize)
n, err := conn.Write(writeBuf)
require.NoError(t, err)
require.EqualValues(t, deadlockSize, n)
err = conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
require.NoError(t, err)
err = conn.Flush()
require.Error(t, err)
require.Contains(t, err.Error(), "i/o timeout")
})
}
func TestNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.SetReadDeadline(nbconn.NonBlockingDeadline)
require.NoError(t, err)
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.ErrorIs(t, err, nbconn.ErrWouldBlock)
require.EqualValues(t, 0, n)
errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()
err = conn.SetReadDeadline(time.Time{})
require.NoError(t, err)
n, err = conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
})
}
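The idiom the test above exercises can be condensed to the following fragment (a sketch against the nbconn API used in these tests; conn, buf, and the enclosing function are assumed to be in scope):

```go
// The sentinel deadline makes Read return nbconn.ErrWouldBlock instead of waiting.
if err := conn.SetReadDeadline(nbconn.NonBlockingDeadline); err != nil {
	return err
}
n, err := conn.Read(buf)
if errors.Is(err, nbconn.ErrWouldBlock) {
	// No data buffered yet; do other work and retry later.
} else if err != nil {
	return err
}
_ = n

// Clear the sentinel to restore normal blocking reads.
if err := conn.SetReadDeadline(time.Time{}); err != nil {
	return err
}
```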
func TestBufferNonBlockingRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
err := conn.BufferReadUntilBlock()
require.NoError(t, err)
errChan := make(chan error, 1)
go func() {
_, err := remote.Write([]byte("okay"))
errChan <- err
}()
for i := 0; i < 1000; i++ {
err = conn.BufferReadUntilBlock()
if !errors.Is(err, nbconn.ErrWouldBlock) {
break
}
time.Sleep(time.Millisecond)
}
require.NoError(t, err)
buf := make([]byte, 4)
n, err := conn.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 4, n)
require.Equal(t, []byte("okay"), buf)
})
}
func TestReadPreviouslyBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 5)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 5, n)
require.Equal(t, []byte("alpha"), readBuf)
})
}
func TestReadMoreThanPreviouslyBufferedDoesNotBlock(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 10)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 5, n)
require.Equal(t, []byte("alpha"), readBuf[:n])
})
}
func TestReadPreviouslyBufferedPartialRead(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 2)
n, err := conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 2, n)
require.Equal(t, []byte("al"), readBuf)
readBuf = make([]byte, 3)
n, err = conn.Read(readBuf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("pha"), readBuf)
})
}
func TestReadMultiplePreviouslyBuffered(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
_, err = remote.Write([]byte("beta"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
readBuf := make([]byte, 9)
n, err := io.ReadFull(conn, readBuf)
require.NoError(t, err)
require.EqualValues(t, 9, n)
require.Equal(t, []byte("alphabeta"), readBuf)
})
}
func TestReadPreviouslyBufferedAndReadMore(t *testing.T) {
testVariants(t, func(t *testing.T, conn nbconn.Conn, remote net.Conn) {
flushCompleteChan := make(chan struct{})
errChan := make(chan error, 1)
go func() {
err := func() error {
_, err := remote.Write([]byte("alpha"))
if err != nil {
return err
}
readBuf := make([]byte, 4)
_, err = remote.Read(readBuf)
if err != nil {
return err
}
<-flushCompleteChan
_, err = remote.Write([]byte("beta"))
if err != nil {
return err
}
return nil
}()
errChan <- err
}()
_, err := conn.Write([]byte("test"))
require.NoError(t, err)
// Because net.Pipe() is synchronous, conn.Flush must buffer a read.
err = conn.Flush()
require.NoError(t, err)
close(flushCompleteChan)
readBuf := make([]byte, 9)
n, err := io.ReadFull(conn, readBuf)
require.NoError(t, err)
require.EqualValues(t, 9, n)
require.Equal(t, []byte("alphabeta"), readBuf)
err = <-errChan
require.NoError(t, err)
})
}

View File

@ -23,7 +23,7 @@ func TestScript(t *testing.T) {
 	script.Steps = append(script.Steps, pgmock.ExpectMessage(&pgproto3.Query{String: "select 42"}))
 	script.Steps = append(script.Steps, pgmock.SendMessage(&pgproto3.RowDescription{
 		Fields: []pgproto3.FieldDescription{
-			pgproto3.FieldDescription{
+			{
 				Name:                 []byte("?column?"),
 				TableOID:             0,
 				TableAttributeNumber: 0,
@ -69,9 +69,7 @@ func TestScript(t *testing.T) {
 		}
 	}()
 
-	parts := strings.Split(ln.Addr().String(), ":")
-	host := parts[0]
-	port := parts[1]
+	host, port, _ := strings.Cut(ln.Addr().String(), ":")
 	connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second)

View File

@ -0,0 +1,60 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# Requires benchstat: go install golang.org/x/perf/cmd/benchstat@latest
benchstat "${bench_files[@]}"

View File

@ -4,8 +4,10 @@ import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
@ -18,44 +20,81 @@ type Query struct {
Parts []Part Parts []Part
} }
// utf8.DecodeRune returns utf8.RuneError for errors. But that is actually rune U+FFFD -- the Unicode replacement
// character -- which can also appear legitimately in the input. A genuine U+FFFD decodes with width 3, so
// utf8.RuneError only signals a decoding error when the width is not 3.
//
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
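A quick standalone check of the width distinction described above (plain unicode/utf8, independent of pgx):

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

func main() {
	// A genuine U+FFFD in the input decodes as RuneError with width 3.
	r, w := utf8.DecodeRuneInString("\uFFFD")
	fmt.Println(r == utf8.RuneError, w) // true 3

	// An invalid byte also decodes as RuneError, but with width 1.
	r, w = utf8.DecodeRuneInString("\xff")
	fmt.Println(r == utf8.RuneError, w) // true 1
}
```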
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) { func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args)) argUse := make([]bool, len(args))
buf := &bytes.Buffer{} buf := bufPool.get()
defer bufPool.put(buf)
for _, part := range q.Parts { for _, part := range q.Parts {
var str string
switch part := part.(type) { switch part := part.(type) {
case string: case string:
str = part buf.WriteString(part)
case int: case int:
argIdx := part - 1 argIdx := part - 1
var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
if argIdx >= len(args) { if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments") return "", fmt.Errorf("insufficient arguments")
} }
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
arg := args[argIdx] arg := args[argIdx]
switch arg := arg.(type) { switch arg := arg.(type) {
case nil: case nil:
str = "null" p = null
case int64: case int64:
str = strconv.FormatInt(arg, 10) p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64: case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64) p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool: case bool:
str = strconv.FormatBool(arg) p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte: case []byte:
str = QuoteBytes(arg) p = QuoteBytes(buf.AvailableBuffer(), arg)
case string: case string:
str = QuoteString(arg) p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time: case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'") p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default: default:
return "", fmt.Errorf("invalid arg type: %T", arg) return "", fmt.Errorf("invalid arg type: %T", arg)
} }
argUse[argIdx] = true argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
default: default:
return "", fmt.Errorf("invalid Part type: %T", part) return "", fmt.Errorf("invalid Part type: %T", part)
} }
buf.WriteString(str)
} }
for i, used := range argUse { for i, used := range argUse {
@ -67,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
} }
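The space written around each interpolated argument defends against line-comment creation (GHSA-m7wr-2xf7-cm9p). A standalone illustration, not pgx code, of why the padding matters:

```go
package main

import "fmt"

func main() {
	arg := "-1" // a sanitized negative number
	// Without padding, "1-" followed by "-1" fuses into "--", which starts a
	// line comment in SQL and swallows the rest of the line.
	fmt.Println("select 1-" + arg) // select 1--1
	// With padding the tokens stay separate, as the tests later in this diff expect.
	fmt.Println("select 1- " + arg + " ") // select 1- -1
}
```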
func NewQuery(sql string) (*Query, error) { func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{ query := &Query{}
src: sql, query.init(sql)
stateFn: rawState,
return query, nil
} }
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty but fast heuristic to preallocate for ~90% of use cases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil { for l.stateFn != nil {
l.stateFn = l.stateFn(l) l.stateFn = l.stateFn(l)
} }
query := &Query{Parts: l.parts} q.Parts = l.parts
return query, nil
} }
func QuoteString(str string) string { func QuoteString(dst []byte, str string) []byte {
return "'" + strings.ReplaceAll(str, "'", "''") + "'" const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
} }
func QuoteBytes(buf []byte) string { // Add closing quote
return `'\x` + hex.EncodeToString(buf) + "'" dst = append(dst, quote)
return dst
}
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
} }
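A usage sketch of the append-style API above, written in the style of the test files later in this diff (the sanitize package is internal, so only pgx's own tests can import it):

```go
func ExampleQuoting() {
	buf := make([]byte, 0, 64)
	buf = sanitize.QuoteString(buf, "one's hat")
	buf = append(buf, ' ')
	buf = sanitize.QuoteBytes(buf, []byte{0xde, 0xad})
	fmt.Println(string(buf))
	// Output: 'one''s hat' '\xdead'
}
```

Passing a destination slice in and getting the extended slice back avoids the intermediate string allocations of the old signatures.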
type sqlLexer struct { type sqlLexer struct {
@ -138,6 +250,7 @@ func rawState(l *sqlLexer) stateFn {
return multilineCommentState return multilineCommentState
} }
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@ -146,6 +259,7 @@ func rawState(l *sqlLexer) stateFn {
} }
} }
} }
}
func singleQuoteState(l *sqlLexer) stateFn { func singleQuoteState(l *sqlLexer) stateFn {
for { for {
@ -160,6 +274,7 @@ func singleQuoteState(l *sqlLexer) stateFn {
} }
l.pos += width l.pos += width
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@ -168,6 +283,7 @@ func singleQuoteState(l *sqlLexer) stateFn {
} }
} }
} }
}
func doubleQuoteState(l *sqlLexer) stateFn { func doubleQuoteState(l *sqlLexer) stateFn {
for { for {
@ -182,6 +298,7 @@ func doubleQuoteState(l *sqlLexer) stateFn {
} }
l.pos += width l.pos += width
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@ -190,6 +307,7 @@ func doubleQuoteState(l *sqlLexer) stateFn {
} }
} }
} }
}
// placeholderState consumes a placeholder value. The $ must already have been consumed. The first rune must be a
// digit.
@ -228,6 +346,7 @@ func escapeStringState(l *sqlLexer) stateFn {
} }
l.pos += width l.pos += width
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@ -236,6 +355,7 @@ func escapeStringState(l *sqlLexer) stateFn {
} }
} }
} }
}
func oneLineCommentState(l *sqlLexer) stateFn { func oneLineCommentState(l *sqlLexer) stateFn {
for { for {
@ -249,6 +369,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
case '\n', '\r': case '\n', '\r':
return rawState return rawState
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@ -257,6 +378,7 @@ func oneLineCommentState(l *sqlLexer) stateFn {
} }
} }
} }
}
func multilineCommentState(l *sqlLexer) stateFn { func multilineCommentState(l *sqlLexer) stateFn {
for { for {
@ -283,6 +405,7 @@ func multilineCommentState(l *sqlLexer) stateFn {
l.nested-- l.nested--
case utf8.RuneError: case utf8.RuneError:
if width != replacementcharacterwidth {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos]) l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos l.start = l.pos
@ -291,14 +414,47 @@ func multilineCommentState(l *sqlLexer) stateFn {
} }
} }
} }
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is // as necessary. This function is only safe when standard_conforming_strings is
// on. // on.
func SanitizeSQL(sql string, args ...any) (string, error) { func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql) query := queryPool.get()
if err != nil { query.init(sql)
return "", err defer queryPool.put(query)
}
return query.Sanitize(args...) return query.Sanitize(args...)
} }
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}
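A minimal sketch of the get/put discipline with the bufPool defined above (joinWithSpaces is a hypothetical helper, not part of this diff):

```go
func joinWithSpaces(parts []string) string {
	buf := bufPool.get()
	defer bufPool.put(buf) // reset decides whether the buffer is recycled
	for i, p := range parts {
		if i > 0 {
			buf.WriteByte(' ')
		}
		buf.WriteString(p)
	}
	return buf.String() // String copies, so the buffer is safe to reuse
}
```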

View File

@ -0,0 +1,62 @@
// sanitize_benchmark_test.go
package sanitize_test
import (
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
)
var benchmarkSanitizeResult string
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}
func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}
var benchmarkNewSQLResult string
func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}

View File

@ -0,0 +1,55 @@
package sanitize_test
import (
"strings"
"testing"
"github.com/jackc/pgx/v5/internal/sanitize"
)
func FuzzQuoteString(f *testing.F) {
const prefix = "prefix"
f.Add("new\nline")
f.Add("sample text")
f.Add("sample q'u'o't'e's")
f.Add("select 'quoted $42', $1")
f.Fuzz(func(t *testing.T, input string) {
got := string(sanitize.QuoteString([]byte(prefix), input))
want := oldQuoteString(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}
func FuzzQuoteBytes(f *testing.F) {
const prefix = "prefix"
f.Add([]byte(nil))
f.Add([]byte("\n"))
f.Add([]byte("sample text"))
f.Add([]byte("sample q'u'o't'e's"))
f.Add([]byte("select 'quoted $42', $1"))
f.Fuzz(func(t *testing.T, input []byte) {
got := string(sanitize.QuoteBytes([]byte(prefix), input))
want := oldQuoteBytes(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}

View File

@ -1,6 +1,8 @@
package sanitize_test package sanitize_test
import ( import (
"encoding/hex"
"strings"
"testing" "testing"
"time" "time"
@ -88,6 +90,16 @@ func TestNewQuery(t *testing.T) {
sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1", sql: "select 42, -- \\nis a Deep Thought's favorite number\r$1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}}, expected: sanitize.Query{Parts: []sanitize.Part{"select 42, -- \\nis a Deep Thought's favorite number\r", 1}},
}, },
{
// https://github.com/jackc/pgx/issues/1380
sql: "select 'hello w�rld'",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello w�rld'"}},
},
{
// Unterminated quoted string
sql: "select 'hello world",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'hello world"}},
},
} }
for i, tt := range successTests { for i, tt := range successTests {
@ -164,6 +176,16 @@ func TestQuerySanitize(t *testing.T) {
args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)}, args: []any{time.Date(2020, time.March, 1, 23, 59, 59, 999999999, time.UTC)},
expected: `insert '2020-03-01 23:59:59.999999Z' `, expected: `insert '2020-03-01 23:59:59.999999Z' `,
}, },
{
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
args: []any{int64(-1)},
expected: `select 1- -1 `,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 1-", 1}},
args: []any{float64(-1)},
expected: `select 1- -1 `,
},
} }
for i, tt := range successfulTests { for i, tt := range successfulTests {
@ -207,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
} }
} }
} }
func TestQuoteString(t *testing.T) {
tc := func(name, input string) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteString(nil, input))
want := oldQuoteString(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("empty", "")
tc("text", "abcd")
tc("with quotes", `one's hat is always a cat`)
}
// This function was used before the optimizations.
// It is kept for testing purposes: we want to ensure there are no breaking changes.
func oldQuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func TestQuoteBytes(t *testing.T) {
tc := func(name string, input []byte) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteBytes(nil, input))
want := oldQuoteBytes(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("nil", nil)
tc("empty", []byte{})
tc("text", []byte("abcd"))
}
// This function was used before the optimizations.
// It is kept for testing purposes: we want to ensure there are no breaking changes.
func oldQuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}

View File

@ -34,7 +34,8 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
 }
 
-// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache.
+// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
+// sd.SQL has been invalidated and HandleInvalidated has not been called yet.
 func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
 	if sd.SQL == "" {
 		panic("cannot store statement description with empty SQL")
@ -44,6 +45,13 @@ func (c *LRUCache) Put(sd *pgconn.StatementDescription) {
 		return
 	}
 
+	// The statement may have been invalidated but not yet handled. Do not readd it to the cache.
+	for _, invalidSD := range c.invalidStmts {
+		if invalidSD.SQL == sd.SQL {
+			return
+		}
+	}
+
 	if c.l.Len() == c.cap {
 		c.invalidateOldest()
 	}
@ -73,10 +81,16 @@ func (c *LRUCache) InvalidateAll() {
 	c.l = list.New()
 }
 
-func (c *LRUCache) HandleInvalidated() []*pgconn.StatementDescription {
-	invalidStmts := c.invalidStmts
+// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
+func (c *LRUCache) GetInvalidated() []*pgconn.StatementDescription {
+	return c.invalidStmts
+}
+
+// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
+// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
+// never seen by the call to GetInvalidated.
+func (c *LRUCache) RemoveInvalidated() {
 	c.invalidStmts = nil
-	return invalidStmts
 }
 
 // Len returns the number of cached prepared statement descriptions.

View File

@ -2,18 +2,17 @@
 package stmtcache
 
 import (
-	"strconv"
-	"sync/atomic"
+	"crypto/sha256"
+	"encoding/hex"
 
 	"github.com/jackc/pgx/v5/pgconn"
 )
 
-var stmtCounter int64
-
-// NextStatementName returns a statement name that will be unique for the lifetime of the program.
-func NextStatementName() string {
-	n := atomic.AddInt64(&stmtCounter, 1)
-	return "stmtcache_" + strconv.FormatInt(n, 10)
+// StatementName returns a statement name that will be stable for sql across multiple connections and program
+// executions.
+func StatementName(sql string) string {
+	digest := sha256.Sum256([]byte(sql))
+	return "stmtcache_" + hex.EncodeToString(digest[0:24])
 }
 // Cache caches statement descriptions.
@ -30,8 +29,13 @@ type Cache interface {
 	// InvalidateAll invalidates all statement descriptions.
 	InvalidateAll()
 
-	// HandleInvalidated returns a slice of all statement descriptions invalidated since the last call to HandleInvalidated.
-	HandleInvalidated() []*pgconn.StatementDescription
+	// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
+	GetInvalidated() []*pgconn.StatementDescription
+
+	// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
+	// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
+	// never seen by the call to GetInvalidated.
+	RemoveInvalidated()
 
 	// Len returns the number of cached prepared statement descriptions.
 	Len() int
@ -39,19 +43,3 @@ type Cache interface {
 	// Cap returns the maximum number of cached prepared statement descriptions.
 	Cap() int
 }
-
-func IsStatementInvalid(err error) bool {
-	pgErr, ok := err.(*pgconn.PgError)
-	if !ok {
-		return false
-	}
-
-	// https://github.com/jackc/pgx/issues/1162
-	//
-	// We used to look for the message "cached plan must not change result type". However, that message can be localized.
-	// Unfortunately, error code "0A000" - "FEATURE NOT SUPPORTED" is used for many different errors and the only way to
-	// tell the difference is by the message. But all that happens is we clear a statement that we otherwise wouldn't
-	// have so it should be safe.
-	possibleInvalidCachedPlanError := pgErr.Code == "0A000"
-	return possibleInvalidCachedPlanError
-}
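The GetInvalidated/RemoveInvalidated split enables a two-phase cleanup. A hedged sketch of the intended call pattern (deallocate is a hypothetical cleanup callback, not pgx API; the point is that no other Cache calls happen between the two steps):

```go
func cleanupInvalidated(cache Cache, deallocate func(name string) error) error {
	for _, sd := range cache.GetInvalidated() {
		if err := deallocate(sd.Name); err != nil {
			return err // statements stay listed and can be retried later
		}
	}
	cache.RemoveInvalidated()
	return nil
}
```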

View File

@ -54,10 +54,16 @@ func (c *UnlimitedCache) InvalidateAll() {
 	c.m = make(map[string]*pgconn.StatementDescription)
 }
 
-func (c *UnlimitedCache) HandleInvalidated() []*pgconn.StatementDescription {
-	invalidStmts := c.invalidStmts
+// GetInvalidated returns a slice of all statement descriptions invalidated since the last call to RemoveInvalidated.
+func (c *UnlimitedCache) GetInvalidated() []*pgconn.StatementDescription {
+	return c.invalidStmts
+}
+
+// RemoveInvalidated removes all invalidated statement descriptions. No other calls to Cache must be made between a
+// call to GetInvalidated and RemoveInvalidated or RemoveInvalidated may remove statement descriptions that were
+// never seen by the call to GetInvalidated.
+func (c *UnlimitedCache) RemoveInvalidated() {
 	c.invalidStmts = nil
-	return invalidStmts
 }
 
 // Len returns the number of cached prepared statement descriptions.

View File

@ -4,8 +4,15 @@ import (
 	"context"
 	"errors"
 	"io"
+
+	"github.com/jackc/pgx/v5/pgtype"
 )
 
+// The PostgreSQL wire protocol has a limit of 1 GB - 1 per message. See definition of
+// PQ_LARGE_MESSAGE_LIMIT in the PostgreSQL source code. To allow for the other data
+// in the message, maxLargeObjectMessageLength should be no larger than 1 GB - 1 KB.
+var maxLargeObjectMessageLength = 1024*1024*1024 - 1024
+
 // LargeObjects is a structure used to access the large objects API. It is only valid within the transaction where it
 // was created.
 //
@ -68,32 +75,65 @@ type LargeObject struct {
 // Write writes p to the large object and returns the number of bytes written and an error if not all of p was written.
 func (o *LargeObject) Write(p []byte) (int, error) {
-	var n int
-	err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p).Scan(&n)
-	if err != nil {
-		return n, err
-	}
-	if n < 0 {
-		return 0, errors.New("failed to write to large object")
-	}
-	return n, nil
+	nTotal := 0
+	for {
+		expected := len(p) - nTotal
+		if expected == 0 {
+			break
+		} else if expected > maxLargeObjectMessageLength {
+			expected = maxLargeObjectMessageLength
+		}
+
+		var n int
+		err := o.tx.QueryRow(o.ctx, "select lowrite($1, $2)", o.fd, p[nTotal:nTotal+expected]).Scan(&n)
+		if err != nil {
+			return nTotal, err
+		}
+		if n < 0 {
+			return nTotal, errors.New("failed to write to large object")
+		}
+
+		nTotal += n
+		if n < expected {
+			return nTotal, errors.New("short write to large object")
+		} else if n > expected {
+			return nTotal, errors.New("invalid write to large object")
+		}
+	}
+
+	return nTotal, nil
 }
 
 // Read reads up to len(p) bytes into p returning the number of bytes read.
 func (o *LargeObject) Read(p []byte) (int, error) {
-	var res []byte
-	err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, len(p)).Scan(&res)
-	copy(p, res)
-	if err != nil {
-		return len(res), err
-	}
-	if len(res) < len(p) {
-		err = io.EOF
-	}
-	return len(res), err
+	nTotal := 0
+	for {
+		expected := len(p) - nTotal
+		if expected == 0 {
+			break
+		} else if expected > maxLargeObjectMessageLength {
+			expected = maxLargeObjectMessageLength
+		}
+
+		res := pgtype.PreallocBytes(p[nTotal:])
+		err := o.tx.QueryRow(o.ctx, "select loread($1, $2)", o.fd, expected).Scan(&res)
+		// We compute expected so that it always fits into p, so it should never happen
+		// that PreallocBytes's ScanBytes has to allocate a new slice.
+		nTotal += len(res)
+		if err != nil {
+			return nTotal, err
+		}
+
+		if len(res) < expected {
+			return nTotal, io.EOF
+		} else if len(res) > expected {
+			return nTotal, errors.New("invalid read of large object")
+		}
+	}
+
+	return nTotal, nil
 }
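Write and Read above share one chunking pattern. A standalone sketch of that loop, generic over any chunk-limited writer:

```go
package main

import (
	"errors"
	"fmt"
)

func writeChunked(p []byte, limit int, write func([]byte) (int, error)) (int, error) {
	nTotal := 0
	for nTotal < len(p) {
		chunk := p[nTotal:]
		if len(chunk) > limit {
			chunk = chunk[:limit] // never exceed the per-message limit
		}
		n, err := write(chunk)
		nTotal += n
		if err != nil {
			return nTotal, err
		}
		if n < len(chunk) {
			return nTotal, errors.New("short write")
		}
	}
	return nTotal, nil
}

func main() {
	var out []byte
	n, err := writeChunked([]byte("hello world"), 4, func(b []byte) (int, error) {
		out = append(out, b...)
		return len(b), nil
	})
	fmt.Println(n, err, string(out)) // 11 <nil> hello world
}
```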
// Seek moves the current location pointer to the new location specified by offset. // Seek moves the current location pointer to the new location specified by offset.

View File

@ -0,0 +1,20 @@
package pgx
import (
"testing"
)
// SetMaxLargeObjectMessageLength sets internal maxLargeObjectMessageLength variable
// to the given length for the duration of the test.
//
// Tests using this helper should not use t.Parallel().
func SetMaxLargeObjectMessageLength(t *testing.T, length int) {
t.Helper()
original := maxLargeObjectMessageLength
t.Cleanup(func() {
maxLargeObjectMessageLength = original
})
maxLargeObjectMessageLength = length
}

View File

@ -13,9 +13,10 @@ import (
 )
 
 func TestLargeObjects(t *testing.T) {
-	t.Parallel()
+	// We use a very short limit to test chunking logic.
+	pgx.SetMaxLargeObjectMessageLength(t, 2)
 
-	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
 
 	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
@ -34,9 +35,10 @@ func TestLargeObjects(t *testing.T) {
 }
 
 func TestLargeObjectsSimpleProtocol(t *testing.T) {
-	t.Parallel()
+	// We use a very short limit to test chunking logic.
+	pgx.SetMaxLargeObjectMessageLength(t, 2)
 
-	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
 
 	config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
@ -160,9 +162,10 @@ func testLargeObjects(t *testing.T, ctx context.Context, tx pgx.Tx) {
 }
 
 func TestLargeObjectsMultipleTransactions(t *testing.T) {
-	t.Parallel()
+	// We use a very short limit to test chunking logic.
+	pgx.SetMaxLargeObjectMessageLength(t, 2)
 
-	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
 	defer cancel()
 
 	conn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))

multitracer/tracer.go Normal file
View File

@ -0,0 +1,152 @@
// Package multitracer provides a Tracer that can combine several tracers into one.
package multitracer
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Tracer can combine several tracers into one.
// You can use New to automatically split tracers by interface.
type Tracer struct {
QueryTracers []pgx.QueryTracer
BatchTracers []pgx.BatchTracer
CopyFromTracers []pgx.CopyFromTracer
PrepareTracers []pgx.PrepareTracer
ConnectTracers []pgx.ConnectTracer
PoolAcquireTracers []pgxpool.AcquireTracer
PoolReleaseTracers []pgxpool.ReleaseTracer
}
// New returns a new Tracer that combines the given tracers, splitting them automatically by the interfaces they implement.
func New(tracers ...pgx.QueryTracer) *Tracer {
var t Tracer
for _, tracer := range tracers {
t.QueryTracers = append(t.QueryTracers, tracer)
if batchTracer, ok := tracer.(pgx.BatchTracer); ok {
t.BatchTracers = append(t.BatchTracers, batchTracer)
}
if copyFromTracer, ok := tracer.(pgx.CopyFromTracer); ok {
t.CopyFromTracers = append(t.CopyFromTracers, copyFromTracer)
}
if prepareTracer, ok := tracer.(pgx.PrepareTracer); ok {
t.PrepareTracers = append(t.PrepareTracers, prepareTracer)
}
if connectTracer, ok := tracer.(pgx.ConnectTracer); ok {
t.ConnectTracers = append(t.ConnectTracers, connectTracer)
}
if poolAcquireTracer, ok := tracer.(pgxpool.AcquireTracer); ok {
t.PoolAcquireTracers = append(t.PoolAcquireTracers, poolAcquireTracer)
}
if poolReleaseTracer, ok := tracer.(pgxpool.ReleaseTracer); ok {
t.PoolReleaseTracers = append(t.PoolReleaseTracers, poolReleaseTracer)
}
}
return &t
}
func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
for _, tracer := range t.QueryTracers {
ctx = tracer.TraceQueryStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
for _, tracer := range t.QueryTracers {
tracer.TraceQueryEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
for _, tracer := range t.BatchTracers {
ctx = tracer.TraceBatchStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
for _, tracer := range t.BatchTracers {
tracer.TraceBatchQuery(ctx, conn, data)
}
}
func (t *Tracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
for _, tracer := range t.BatchTracers {
tracer.TraceBatchEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
for _, tracer := range t.CopyFromTracers {
ctx = tracer.TraceCopyFromStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
for _, tracer := range t.CopyFromTracers {
tracer.TraceCopyFromEnd(ctx, conn, data)
}
}
func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
for _, tracer := range t.PrepareTracers {
ctx = tracer.TracePrepareStart(ctx, conn, data)
}
return ctx
}
func (t *Tracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
for _, tracer := range t.PrepareTracers {
tracer.TracePrepareEnd(ctx, conn, data)
}
}
func (t *Tracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
for _, tracer := range t.ConnectTracers {
ctx = tracer.TraceConnectStart(ctx, data)
}
return ctx
}
func (t *Tracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
for _, tracer := range t.ConnectTracers {
tracer.TraceConnectEnd(ctx, data)
}
}
func (t *Tracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
for _, tracer := range t.PoolAcquireTracers {
ctx = tracer.TraceAcquireStart(ctx, pool, data)
}
return ctx
}
func (t *Tracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
for _, tracer := range t.PoolAcquireTracers {
tracer.TraceAcquireEnd(ctx, pool, data)
}
}
func (t *Tracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
for _, tracer := range t.PoolReleaseTracers {
tracer.TraceRelease(pool, data)
}
}
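A hedged usage sketch (fragment; tracerA and tracerB stand in for any pgx.QueryTracer implementations, and error handling is elided):

```go
cfg, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
if err != nil {
	log.Fatal(err)
}
cfg.ConnConfig.Tracer = multitracer.New(tracerA, tracerB)
pool, err := pgxpool.NewWithConfig(context.Background(), cfg)
```

Tracers that implement only some interfaces, like testCopyTracer below, are fanned out only for the callbacks they support.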

multitracer/tracer_test.go Normal file
View File

@ -0,0 +1,115 @@
package multitracer_test
import (
"context"
"testing"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/multitracer"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
)
type testFullTracer struct{}
func (tt *testFullTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}
func (tt *testFullTracer) TraceBatchStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchQueryData) {
}
func (tt *testFullTracer) TraceBatchEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceBatchEndData) {
}
func (tt *testFullTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
}
func (tt *testFullTracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TracePrepareEnd(ctx context.Context, conn *pgx.Conn, data pgx.TracePrepareEndData) {
}
func (tt *testFullTracer) TraceConnectStart(ctx context.Context, data pgx.TraceConnectStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceConnectEnd(ctx context.Context, data pgx.TraceConnectEndData) {
}
func (tt *testFullTracer) TraceAcquireStart(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireStartData) context.Context {
return ctx
}
func (tt *testFullTracer) TraceAcquireEnd(ctx context.Context, pool *pgxpool.Pool, data pgxpool.TraceAcquireEndData) {
}
func (tt *testFullTracer) TraceRelease(pool *pgxpool.Pool, data pgxpool.TraceReleaseData) {
}
type testCopyTracer struct{}
func (tt *testCopyTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
return ctx
}
func (tt *testCopyTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
}
func (tt *testCopyTracer) TraceCopyFromStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromStartData) context.Context {
return ctx
}
func (tt *testCopyTracer) TraceCopyFromEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceCopyFromEndData) {
}
func TestNew(t *testing.T) {
t.Parallel()
fullTracer := &testFullTracer{}
copyTracer := &testCopyTracer{}
mt := multitracer.New(fullTracer, copyTracer)
require.Equal(
t,
&multitracer.Tracer{
QueryTracers: []pgx.QueryTracer{
fullTracer,
copyTracer,
},
BatchTracers: []pgx.BatchTracer{
fullTracer,
},
CopyFromTracers: []pgx.CopyFromTracer{
fullTracer,
copyTracer,
},
PrepareTracers: []pgx.PrepareTracer{
fullTracer,
},
ConnectTracers: []pgx.ConnectTracer{
fullTracer,
},
PoolAcquireTracers: []pgxpool.AcquireTracer{
fullTracer,
},
PoolReleaseTracers: []pgxpool.ReleaseTracer{
fullTracer,
},
},
mt,
)
}

View File

@ -2,6 +2,7 @@ package pgx
import ( import (
"context" "context"
"fmt"
"strconv" "strconv"
"strings" "strings"
"unicode/utf8" "unicode/utf8"
@ -12,12 +13,43 @@ import (
// //
// For example, the following two queries are equivalent: // For example, the following two queries are equivalent:
// //
// conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})) // conn.Query(ctx, "select * from widgets where foo = @foo and bar = @bar", pgx.NamedArgs{"foo": 1, "bar": 2})
// conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2})) // conn.Query(ctx, "select * from widgets where foo = $1 and bar = $2", 1, 2)
//
// Named placeholders are case sensitive and must start with a letter or underscore. Subsequent characters can be
// letters, numbers, or underscores.
type NamedArgs map[string]any type NamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface. // RewriteQuery implements the QueryRewriter interface.
func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any) { func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(na, sql, false)
}
// StrictNamedArgs can be used in the same way as NamedArgs, but provided arguments are also checked to include all
// named arguments that the sql query uses, and no extra arguments.
type StrictNamedArgs map[string]any
// RewriteQuery implements the QueryRewriter interface.
func (sna StrictNamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error) {
return rewriteQuery(sna, sql, true)
}
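A brief usage sketch contrasting the two types (fragment; conn and ctx assumed in scope):

```go
// NamedArgs passes nil for names missing from the map and ignores extras.
rows, err := conn.Query(ctx,
	"select * from widgets where foo = @foo and bar = @bar",
	pgx.NamedArgs{"foo": 1, "bar": 2})

// StrictNamedArgs instead returns an error when the map and query disagree.
_, err = conn.Query(ctx,
	"select * from widgets where foo = @foo",
	pgx.StrictNamedArgs{"foo": 1, "extra": 2}) // error: "extra" not in query
_ = rows
```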
type namedArg string
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
}
type stateFn func(*sqlLexer) stateFn
func rewriteQuery(na map[string]any, sql string, isStrict bool) (newSQL string, newArgs []any, err error) {
l := &sqlLexer{ l := &sqlLexer{
src: sql, src: sql,
stateFn: rawState, stateFn: rawState,
@ -41,27 +73,24 @@ func (na NamedArgs) RewriteQuery(ctx context.Context, conn *Conn, sql string, ar
newArgs = make([]any, len(l.nameToOrdinal)) newArgs = make([]any, len(l.nameToOrdinal))
for name, ordinal := range l.nameToOrdinal { for name, ordinal := range l.nameToOrdinal {
newArgs[ordinal-1] = na[string(name)] var found bool
newArgs[ordinal-1], found = na[string(name)]
if isStrict && !found {
return "", nil, fmt.Errorf("argument %s found in sql query but not present in StrictNamedArgs", name)
}
} }
return sb.String(), newArgs if isStrict {
for name := range na {
if _, found := l.nameToOrdinal[namedArg(name)]; !found {
return "", nil, fmt.Errorf("argument %s of StrictNamedArgs not found in sql query", name)
}
}
} }
type namedArg string return sb.String(), newArgs, nil
type sqlLexer struct {
src string
start int
pos int
nested int // multiline comment nesting level.
stateFn stateFn
parts []any
nameToOrdinal map[namedArg]int
} }
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn { func rawState(l *sqlLexer) stateFn {
for { for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:]) r, width := utf8.DecodeRuneInString(l.src[l.pos:])
@ -80,7 +109,7 @@ func rawState(l *sqlLexer) stateFn {
return doubleQuoteState return doubleQuoteState
case '@': case '@':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if isLetter(nextRune) { if isLetter(nextRune) || nextRune == '_' {
if l.pos-l.start > 0 { if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width]) l.parts = append(l.parts, l.src[l.start:l.pos-width])
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNamedArgsRewriteQuery(t *testing.T) { func TestNamedArgsRewriteQuery(t *testing.T) {
@ -37,10 +38,10 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
expectedArgs: []any{int32(42), "foo"}, expectedArgs: []any{int32(42), "foo"},
}, },
{ {
sql: "select @Abc::int, @b_4::text", sql: "select @Abc::int, @b_4::text, @_c::int",
namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo"}, namedArgs: pgx.NamedArgs{"Abc": int32(42), "b_4": "foo", "_c": int32(1)},
expectedSQL: "select $1::int, $2::text", expectedSQL: "select $1::int, $2::text, $3::int",
expectedArgs: []any{int32(42), "foo"}, expectedArgs: []any{int32(42), "foo", int32(1)},
}, },
{ {
sql: "at end @", sql: "at end @",
@ -49,15 +50,15 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
expectedArgs: []any{}, expectedArgs: []any{},
}, },
{ {
sql: "ignores without letter after @ foo bar", sql: "ignores without valid character after @ foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "ignores without letter after @ foo bar", expectedSQL: "ignores without valid character after @ foo bar",
expectedArgs: []any{}, expectedArgs: []any{},
}, },
{ {
sql: "name must start with letter @1 foo bar", sql: "name cannot start with number @1 foo bar",
namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"}, namedArgs: pgx.NamedArgs{"a": int32(42), "b": "foo"},
expectedSQL: "name must start with letter @1 foo bar", expectedSQL: "name cannot start with number @1 foo bar",
expectedArgs: []any{}, expectedArgs: []any{},
}, },
{ {
@ -92,11 +93,70 @@ func TestNamedArgsRewriteQuery(t *testing.T) {
where id = $1;`, where id = $1;`,
expectedArgs: []any{int32(42)}, expectedArgs: []any{int32(42)},
}, },
{
sql: "extra provided argument",
namedArgs: pgx.NamedArgs{"extra": int32(1)},
expectedSQL: "extra provided argument",
expectedArgs: []any{},
},
{
sql: "@missing argument",
namedArgs: pgx.NamedArgs{},
expectedSQL: "$1 argument",
expectedArgs: []any{nil},
},
// test comments and quotes // test comments and quotes
} { } {
sql, args := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args) sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, tt.args)
require.NoError(t, err)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i) assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i) assert.Equalf(t, tt.expectedArgs, args, "%d", i)
} }
} }
func TestStrictNamedArgsRewriteQuery(t *testing.T) {
t.Parallel()
for i, tt := range []struct {
sql string
namedArgs pgx.StrictNamedArgs
expectedSQL string
expectedArgs []any
isExpectedError bool
}{
{
sql: "no arguments",
namedArgs: pgx.StrictNamedArgs{},
expectedSQL: "no arguments",
expectedArgs: []any{},
isExpectedError: false,
},
{
sql: "@all @matches",
namedArgs: pgx.StrictNamedArgs{"all": int32(1), "matches": int32(2)},
expectedSQL: "$1 $2",
expectedArgs: []any{int32(1), int32(2)},
isExpectedError: false,
},
{
sql: "extra provided argument",
namedArgs: pgx.StrictNamedArgs{"extra": int32(1)},
isExpectedError: true,
},
{
sql: "@missing argument",
namedArgs: pgx.StrictNamedArgs{},
isExpectedError: true,
},
} {
sql, args, err := tt.namedArgs.RewriteQuery(context.Background(), nil, tt.sql, nil)
if tt.isExpectedError {
assert.Errorf(t, err, "%d", i)
} else {
require.NoErrorf(t, err, "%d", i)
assert.Equalf(t, tt.expectedSQL, sql, "%d", i)
assert.Equalf(t, tt.expectedArgs, args, "%d", i)
}
}
}

View File

@ -26,28 +26,4 @@ if err != nil {
 ## Testing
 
-The pgconn tests require a PostgreSQL database. It will connect to the database specified in the `PGX_TEST_CONN_STRING`
-environment variable. The `PGX_TEST_CONN_STRING` environment variable can be a URL or DSN. In addition, the standard `PG*`
-environment variables will be respected. Consider using [direnv](https://github.com/direnv/direnv) to simplify
-environment variable handling.
-
-### Example Test Environment
-
-Connect to your PostgreSQL server and run:
-
-```
-create database pgx_test;
-```
-
-Now you can run the tests:
-
-```bash
-PGX_TEST_CONN_STRING="host=/var/run/postgresql dbname=pgx_test" go test ./...
-```
-
-### Connection and Authentication Tests
-
-Pgconn supports multiple connection types and means of authentication. These tests are optional. They
-will only run if the appropriate environment variable is set. Run `go test -v | grep SKIP` to see if any tests are being
-skipped. Most developers will not need to enable these tests. See `ci/setup_test.bash` for an example set up if you need change
-authentication code.
+See CONTRIBUTING.md for setup instructions.

View File

@ -42,12 +42,12 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
 		Data: sc.clientFirstMessage(),
 	}
 	c.frontend.Send(saslInitialResponse)
-	err = c.frontend.Flush()
+	err = c.flushWithPotentialWriteReadDeadlock()
 	if err != nil {
 		return err
 	}
 
-	// Receive server-first-message payload in a AuthenticationSASLContinue.
+	// Receive server-first-message payload in an AuthenticationSASLContinue.
 	saslContinue, err := c.rxSASLContinue()
 	if err != nil {
 		return err
@ -62,12 +62,12 @@ func (c *PgConn) scramAuth(serverAuthMechanisms []string) error {
 		Data: []byte(sc.clientFinalMessage()),
 	}
 	c.frontend.Send(saslResponse)
-	err = c.frontend.Flush()
+	err = c.flushWithPotentialWriteReadDeadlock()
 	if err != nil {
 		return err
 	}
 
-	// Receive server-final-message payload in a AuthenticationSASLFinal.
+	// Receive server-final-message payload in an AuthenticationSASLFinal.
 	saslFinal, err := c.rxSASLFinal()
 	if err != nil {
 		return err

View File

@ -53,7 +53,7 @@ func BenchmarkExec(b *testing.B) {
 	for _, bm := range benchmarks {
 		bm := bm
 		b.Run(bm.name, func(b *testing.B) {
-			conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
+			conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
 			require.Nil(b, err)
 			defer closeConn(b, conn)
@ -97,7 +97,7 @@ func BenchmarkExec(b *testing.B) {
 }
 
 func BenchmarkExecPossibleToCancel(b *testing.B) {
-	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
+	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
 	require.Nil(b, err)
 	defer closeConn(b, conn)
@ -159,7 +159,7 @@ func BenchmarkExecPrepared(b *testing.B) {
 	for _, bm := range benchmarks {
 		bm := bm
 		b.Run(bm.name, func(b *testing.B) {
-			conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_CONN_STRING"))
+			conn, err := pgconn.Connect(bm.ctx, os.Getenv("PGX_TEST_DATABASE"))
 			require.Nil(b, err)
 			defer closeConn(b, conn)
@ -197,7 +197,7 @@ func BenchmarkExecPrepared(b *testing.B) {
 }
 
 func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
-	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
+	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
 	require.Nil(b, err)
 	defer closeConn(b, conn)
@ -238,7 +238,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
 }
 
 // func BenchmarkChanToSetDeadlinePossibleToCancel(b *testing.B) {
-// 	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
+// 	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
 // 	require.Nil(b, err)
 // 	defer closeConn(b, conn)

View File

@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"net" "net"
"net/url" "net/url"
@ -20,6 +19,7 @@ import (
"github.com/jackc/pgpassfile" "github.com/jackc/pgpassfile"
"github.com/jackc/pgservicefile" "github.com/jackc/pgservicefile"
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgproto3"
) )
@ -27,7 +27,7 @@ type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
type GetSSLPasswordFunc func(ctx context.Context) string type GetSSLPasswordFunc func(ctx context.Context) string
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by ParseConfig. A // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic. // manually initialized Config will cause ConnectConfig to panic.
type Config struct { type Config struct {
Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp) Host string // host (e.g. localhost) or absolute path to unix domain socket directory (e.g. /private/tmp)
@ -40,12 +40,19 @@ type Config struct {
DialFunc DialFunc // e.g. net.Dialer.DialContext DialFunc DialFunc // e.g. net.Dialer.DialContext
LookupFunc LookupFunc // e.g. net.Resolver.LookupHost LookupFunc LookupFunc // e.g. net.Resolver.LookupHost
BuildFrontend BuildFrontendFunc BuildFrontend BuildFrontendFunc
// BuildContextWatcherHandler is called to create a ContextWatcherHandler for a connection. The handler is called
// when a context passed to a PgConn method is canceled.
BuildContextWatcherHandler func(*PgConn) ctxwatch.Handler
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
KerberosSrvName string KerberosSrvName string
KerberosSpn string KerberosSpn string
Fallbacks []*FallbackConfig Fallbacks []*FallbackConfig
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server. // ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next // It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs. // fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
@ -61,12 +68,17 @@ type Config struct {
// OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received. // OnNotification is a callback function called when a notification from the LISTEN/NOTIFY system is received.
OnNotification NotificationHandler OnNotification NotificationHandler
// OnPgError is a callback function called when a Postgres error is received from the server. The default handler will close
// the connection on any FATAL errors. If you override this handler you should call the previously set handler or ensure
// that you close on FATAL errors by returning false.
OnPgError PgErrorHandler
createdByParseConfig bool // Used to enforce created by ParseConfig rule. createdByParseConfig bool // Used to enforce created by ParseConfig rule.
} }
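
The new hooks above can be set on a parsed config before connecting. A minimal sketch, assuming a server named by a hypothetical DATABASE_URL environment variable; the logging is illustrative, and the replacement handler preserves the default close-on-FATAL behavior:

package main

import (
	"context"
	"log"
	"os"
	"strings"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	cfg, err := pgconn.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	// Log every server error, but keep closing the connection on FATAL
	// errors, matching the default handler installed by ParseConfig.
	cfg.OnPgError = func(_ *pgconn.PgConn, pgErr *pgconn.PgError) bool {
		log.Printf("%s: %s", pgErr.Severity, pgErr.Message)
		return !strings.EqualFold(pgErr.Severity, "FATAL")
	}
	conn, err := pgconn.ConnectConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}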
// ParseConfigOptions contains options that control how a config is built such as getsslpassword. // ParseConfigOptions contains options that control how a config is built such as GetSSLPassword.
type ParseConfigOptions struct { type ParseConfigOptions struct {
// GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the the libpq function // GetSSLPassword gets the password to decrypt a SSL client certificate. This is analogous to the libpq function
// PQsetSSLKeyPassHook_OpenSSL. // PQsetSSLKeyPassHook_OpenSSL.
GetSSLPassword GetSSLPasswordFunc GetSSLPassword GetSSLPasswordFunc
} }
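
A sketch of supplying GetSSLPassword through ParseConfigWithOptions; the connection string, certificate paths, and the environment variable used as a secret source are all illustrative:

package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	opts := pgconn.ParseConfigOptions{
		// Consulted when the sslkey file turns out to be encrypted.
		GetSSLPassword: func(ctx context.Context) string {
			return os.Getenv("SSLKEY_PASSWORD")
		},
	}
	cfg, err := pgconn.ParseConfigWithOptions(
		"host=db.example.com sslmode=verify-full sslcert=client.crt sslkey=client.key",
		opts,
	)
	if err != nil {
		log.Fatal(err)
	}
	_ = cfg
}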
@ -108,6 +120,14 @@ type FallbackConfig struct {
TLSConfig *tls.Config // nil disables TLS TLSConfig *tls.Config // nil disables TLS
} }
// connectOneConfig is the configuration for a single attempt to connect to a single host.
type connectOneConfig struct {
network string
address string
originalHostname string // original hostname before resolving
tlsConfig *tls.Config // nil disables TLS
}
// isAbsolutePath checks if the provided value is an absolute path either // isAbsolutePath checks if the provided value is an absolute path either
// beginning with a forward slash (as on Linux-based systems) or with a capital // beginning with a forward slash (as on Linux-based systems) or with a capital
// letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows). // letter A-Z followed by a colon and a backslash, e.g., "C:\", (as on Windows).
@ -142,11 +162,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It // ParseConfig builds a *Config from connString with similar behavior to the PostgreSQL standard C library libpq. It
// uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely // uses the same defaults as libpq (e.g. port=5432) and understands most PG* environment variables. ParseConfig closely
// matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format (DSN style). // matches the parsing behavior of libpq. connString may either be in URL format or keyword = value format. See
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be // https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for details. connString also may be empty
// empty to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file. // to only read from the environment. If a password is not supplied it will attempt to read the .pgpass file.
// //
// # Example DSN // # Example Keyword/Value
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca // user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
// //
// # Example URL // # Example URL
@ -165,7 +185,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
// //
// ParseConfig currently recognizes the following environment variables and their parameter keyword equivalents passed // ParseConfig currently recognizes the following environment variables and their parameter keyword equivalents passed
// via database URL or DSN: // via database URL or keyword/value:
// //
// PGHOST // PGHOST
// PGPORT // PGPORT
@ -180,9 +200,11 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGSSLKEY // PGSSLKEY
// PGSSLROOTCERT // PGSSLROOTCERT
// PGSSLPASSWORD // PGSSLPASSWORD
// PGOPTIONS
// PGAPPNAME // PGAPPNAME
// PGCONNECT_TIMEOUT // PGCONNECT_TIMEOUT
// PGTARGETSESSIONATTRS // PGTARGETSESSIONATTRS
// PGTZ
// //
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables. // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
// //
@ -211,7 +233,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// //
// In addition, ParseConfig accepts the following options: // In addition, ParseConfig accepts the following options:
// //
// servicefile // - servicefile.
// libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a // libpq only reads servicefile from the PGSERVICEFILE environment variable. ParseConfig accepts servicefile as a
// part of the connection string. // part of the connection string.
func ParseConfig(connString string) (*Config, error) { func ParseConfig(connString string) (*Config, error) {
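
A usage sketch of both accepted formats (hosts and credentials are illustrative); with multiple hosts, the first becomes the primary Host and the rest land in Fallbacks:

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	// URL format with two hosts.
	cfg, err := pgconn.ParseConfig("postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.Host, len(cfg.Fallbacks)) // foo.example.com 1

	// Keyword/value format parses to the same structure.
	cfg, err = pgconn.ParseConfig("user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cfg.Host, cfg.Port) // pg.example.com 5432
}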
@ -229,16 +251,16 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
connStringSettings := make(map[string]string) connStringSettings := make(map[string]string)
if connString != "" { if connString != "" {
var err error var err error
// connString may be a database URL or a DSN // connString may be a database URL or in PostgreSQL keyword/value format
if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") { if strings.HasPrefix(connString, "postgres://") || strings.HasPrefix(connString, "postgresql://") {
connStringSettings, err = parseURLSettings(connString) connStringSettings, err = parseURLSettings(connString)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as URL", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as URL", err: err}
} }
} else { } else {
connStringSettings, err = parseDSNSettings(connString) connStringSettings, err = parseKeywordValueSettings(connString)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to parse as DSN", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to parse as keyword/value", err: err}
} }
} }
} }
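
With the error type now exported, callers can unwrap parse failures; a sketch (the malformed connection string is illustrative, and the field name follows this diff):

package main

import (
	"errors"
	"fmt"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	_, err := pgconn.ParseConfig("host=localhost port=notanumber")
	var parseErr *pgconn.ParseConfigError
	if errors.As(err, &parseErr) {
		fmt.Println("bad connection string:", parseErr.ConnString)
	}
}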
@ -247,7 +269,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
if service, present := settings["service"]; present { if service, present := settings["service"]; present {
serviceSettings, err := parseServiceSettings(settings["servicefile"], service) serviceSettings, err := parseServiceSettings(settings["servicefile"], service)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to read service", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to read service", err: err}
} }
settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings) settings = mergeSettings(defaultSettings, envSettings, serviceSettings, connStringSettings)
@ -262,12 +284,22 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend { BuildFrontend: func(r io.Reader, w io.Writer) *pgproto3.Frontend {
return pgproto3.NewFrontend(r, w) return pgproto3.NewFrontend(r, w)
}, },
BuildContextWatcherHandler: func(pgConn *PgConn) ctxwatch.Handler {
return &DeadlineContextWatcherHandler{Conn: pgConn.conn}
},
OnPgError: func(_ *PgConn, pgErr *PgError) bool {
// we want to automatically close the connection on any FATAL error
if strings.EqualFold(pgErr.Severity, "FATAL") {
return false
}
return true
},
} }
if connectTimeoutSetting, present := settings["connect_timeout"]; present { if connectTimeoutSetting, present := settings["connect_timeout"]; present {
connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting) connectTimeout, err := parseConnectTimeoutSetting(connectTimeoutSetting)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid connect_timeout", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "invalid connect_timeout", err: err}
} }
config.ConnectTimeout = connectTimeout config.ConnectTimeout = connectTimeout
config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout) config.DialFunc = makeConnectTimeoutDialFunc(connectTimeout)
@ -290,7 +322,9 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
"sslkey": {}, "sslkey": {},
"sslcert": {}, "sslcert": {},
"sslrootcert": {}, "sslrootcert": {},
"sslnegotiation": {},
"sslpassword": {}, "sslpassword": {},
"sslsni": {},
"krbspn": {}, "krbspn": {},
"krbsrvname": {}, "krbsrvname": {},
"target_session_attrs": {}, "target_session_attrs": {},
@ -328,7 +362,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
port, err := parsePort(portStr) port, err := parsePort(portStr)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "invalid port", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "invalid port", err: err}
} }
var tlsConfigs []*tls.Config var tlsConfigs []*tls.Config
@ -340,7 +374,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
var err error var err error
tlsConfigs, err = configTLS(settings, host, options) tlsConfigs, err = configTLS(settings, host, options)
if err != nil { if err != nil {
return nil, &parseConfigError{connString: connString, msg: "failed to configure TLS", err: err} return nil, &ParseConfigError{ConnString: connString, msg: "failed to configure TLS", err: err}
} }
} }
@ -357,6 +391,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
config.Port = fallbacks[0].Port config.Port = fallbacks[0].Port
config.TLSConfig = fallbacks[0].TLSConfig config.TLSConfig = fallbacks[0].TLSConfig
config.Fallbacks = fallbacks[1:] config.Fallbacks = fallbacks[1:]
config.SSLNegotiation = settings["sslnegotiation"]
passfile, err := pgpassfile.ReadPassfile(settings["passfile"]) passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
if err == nil { if err == nil {
@ -384,7 +419,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "any": case "any":
// do nothing // do nothing
default: default:
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)} return nil, &ParseConfigError{ConnString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
} }
return config, nil return config, nil
@ -417,11 +452,15 @@ func parseEnvSettings() map[string]string {
"PGSSLMODE": "sslmode", "PGSSLMODE": "sslmode",
"PGSSLKEY": "sslkey", "PGSSLKEY": "sslkey",
"PGSSLCERT": "sslcert", "PGSSLCERT": "sslcert",
"PGSSLSNI": "sslsni",
"PGSSLROOTCERT": "sslrootcert", "PGSSLROOTCERT": "sslrootcert",
"PGSSLPASSWORD": "sslpassword", "PGSSLPASSWORD": "sslpassword",
"PGSSLNEGOTIATION": "sslnegotiation",
"PGTARGETSESSIONATTRS": "target_session_attrs", "PGTARGETSESSIONATTRS": "target_session_attrs",
"PGSERVICE": "service", "PGSERVICE": "service",
"PGSERVICEFILE": "servicefile", "PGSERVICEFILE": "servicefile",
"PGTZ": "timezone",
"PGOPTIONS": "options",
} }
for envname, realname := range nameMap { for envname, realname := range nameMap {
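
With the two new entries above, PGTZ and PGOPTIONS flow into RuntimeParams like any other session default. A sketch (values are illustrative; the updated tests below use t.Setenv for the same effect):

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	os.Setenv("PGTZ", "America/New_York")
	os.Setenv("PGOPTIONS", "-c search_path=myschema")
	cfg, err := pgconn.ParseConfig("")
	if err != nil {
		log.Fatal(err)
	}
	// Both are sent to the server as session defaults at connect time.
	fmt.Println(cfg.RuntimeParams["timezone"], cfg.RuntimeParams["options"])
}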
@ -437,14 +476,17 @@ func parseEnvSettings() map[string]string {
func parseURLSettings(connString string) (map[string]string, error) { func parseURLSettings(connString string) (map[string]string, error) {
settings := make(map[string]string) settings := make(map[string]string)
url, err := url.Parse(connString) parsedURL, err := url.Parse(connString)
if err != nil { if err != nil {
if urlErr := new(url.Error); errors.As(err, &urlErr) {
return nil, urlErr.Err
}
return nil, err return nil, err
} }
if url.User != nil { if parsedURL.User != nil {
settings["user"] = url.User.Username() settings["user"] = parsedURL.User.Username()
if password, present := url.User.Password(); present { if password, present := parsedURL.User.Password(); present {
settings["password"] = password settings["password"] = password
} }
} }
@ -452,7 +494,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
// Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port. // Handle multiple host:port's in url.Host by splitting them into host,host,host and port,port,port.
var hosts []string var hosts []string
var ports []string var ports []string
for _, host := range strings.Split(url.Host, ",") { for _, host := range strings.Split(parsedURL.Host, ",") {
if host == "" { if host == "" {
continue continue
} }
@ -478,7 +520,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
settings["port"] = strings.Join(ports, ",") settings["port"] = strings.Join(ports, ",")
} }
database := strings.TrimLeft(url.Path, "/") database := strings.TrimLeft(parsedURL.Path, "/")
if database != "" { if database != "" {
settings["database"] = database settings["database"] = database
} }
@ -487,7 +529,7 @@ func parseURLSettings(connString string) (map[string]string, error) {
"dbname": "database", "dbname": "database",
} }
for k, v := range url.Query() { for k, v := range parsedURL.Query() {
if k2, present := nameMap[k]; present { if k2, present := nameMap[k]; present {
k = k2 k = k2
} }
@ -504,7 +546,7 @@ func isIPOnly(host string) bool {
var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1} var asciiSpace = [256]uint8{'\t': 1, '\n': 1, '\v': 1, '\f': 1, '\r': 1, ' ': 1}
func parseDSNSettings(s string) (map[string]string, error) { func parseKeywordValueSettings(s string) (map[string]string, error) {
settings := make(map[string]string) settings := make(map[string]string)
nameMap := map[string]string{ nameMap := map[string]string{
@ -515,7 +557,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
var key, val string var key, val string
eqIdx := strings.IndexRune(s, '=') eqIdx := strings.IndexRune(s, '=')
if eqIdx < 0 { if eqIdx < 0 {
return nil, errors.New("invalid dsn") return nil, errors.New("invalid keyword/value")
} }
key = strings.Trim(s[:eqIdx], " \t\n\r\v\f") key = strings.Trim(s[:eqIdx], " \t\n\r\v\f")
@ -567,7 +609,7 @@ func parseDSNSettings(s string) (map[string]string, error) {
} }
if key == "" { if key == "" {
return nil, errors.New("invalid dsn") return nil, errors.New("invalid keyword/value")
} }
settings[key] = val settings[key] = val
@ -612,14 +654,56 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
sslcert := settings["sslcert"] sslcert := settings["sslcert"]
sslkey := settings["sslkey"] sslkey := settings["sslkey"]
sslpassword := settings["sslpassword"] sslpassword := settings["sslpassword"]
sslsni := settings["sslsni"]
sslnegotiation := settings["sslnegotiation"]
// Match libpq default behavior // Match libpq default behavior
if sslmode == "" { if sslmode == "" {
sslmode = "prefer" sslmode = "prefer"
} }
if sslsni == "" {
sslsni = "1"
}
tlsConfig := &tls.Config{} tlsConfig := &tls.Config{}
if sslnegotiation == "direct" {
tlsConfig.NextProtos = []string{"postgresql"}
if sslmode == "prefer" {
sslmode = "require"
}
}
if sslrootcert != "" {
var caCertPool *x509.CertPool
if sslrootcert == "system" {
var err error
caCertPool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("unable to load system certificate pool: %w", err)
}
sslmode = "verify-full"
} else {
caCertPool = x509.NewCertPool()
caPath := sslrootcert
caCert, err := os.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
switch sslmode { switch sslmode {
case "disable": case "disable":
return []*tls.Config{nil}, nil return []*tls.Config{nil}, nil
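
The new sslrootcert=system branch above loads the operating system trust store and promotes sslmode to verify-full. A sketch with an illustrative host:

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	cfg, err := pgconn.ParseConfig("postgres://db.example.com/mydb?sslmode=require&sslrootcert=system")
	if err != nil {
		log.Fatal(err)
	}
	// The system CA pool is installed and full verification is enforced.
	fmt.Println(cfg.TLSConfig.ServerName, cfg.TLSConfig.RootCAs != nil)
}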
@ -677,33 +761,19 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
return nil, errors.New("sslmode is invalid") return nil, errors.New("sslmode is invalid")
} }
if sslrootcert != "" {
caCertPool := x509.NewCertPool()
caPath := sslrootcert
caCert, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, fmt.Errorf("unable to read CA file: %w", err)
}
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, errors.New("unable to add CA to cert pool")
}
tlsConfig.RootCAs = caCertPool
tlsConfig.ClientCAs = caCertPool
}
if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") { if (sslcert != "" && sslkey == "") || (sslcert == "" && sslkey != "") {
return nil, errors.New(`both "sslcert" and "sslkey" are required`) return nil, errors.New(`both "sslcert" and "sslkey" are required`)
} }
if sslcert != "" && sslkey != "" { if sslcert != "" && sslkey != "" {
buf, err := ioutil.ReadFile(sslkey) buf, err := os.ReadFile(sslkey)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read sslkey: %w", err) return nil, fmt.Errorf("unable to read sslkey: %w", err)
} }
block, _ := pem.Decode(buf) block, _ := pem.Decode(buf)
if block == nil {
return nil, errors.New("failed to decode sslkey")
}
var pemKey []byte var pemKey []byte
var decryptedKey []byte var decryptedKey []byte
var decryptedError error var decryptedError error
@ -738,7 +808,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
} else { } else {
pemKey = pem.EncodeToMemory(block) pemKey = pem.EncodeToMemory(block)
} }
certfile, err := ioutil.ReadFile(sslcert) certfile, err := os.ReadFile(sslcert)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to read cert: %w", err) return nil, fmt.Errorf("unable to read cert: %w", err)
} }
@ -749,6 +819,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
tlsConfig.Certificates = []tls.Certificate{cert} tlsConfig.Certificates = []tls.Certificate{cert}
} }
// Set Server Name Indication (SNI), if enabled by connection parameters.
// Per RFC 6066, do not set it if the host is a literal IP address (IPv4
// or IPv6).
if sslsni == "1" && net.ParseIP(host) == nil {
tlsConfig.ServerName = host
}
switch sslmode { switch sslmode {
case "allow": case "allow":
return []*tls.Config{nil, tlsConfig}, nil return []*tls.Config{nil, tlsConfig}, nil
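
The SNI behavior above can be observed from the parsed TLS configs; a sketch with an illustrative host, showing sslsni=0 suppressing ServerName:

package main

import (
	"fmt"
	"log"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	withSNI, err := pgconn.ParseConfig("postgres://jack:secret@db.example.com/mydb?sslmode=require")
	if err != nil {
		log.Fatal(err)
	}
	noSNI, err := pgconn.ParseConfig("postgres://jack:secret@db.example.com/mydb?sslmode=require&sslsni=0")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%q %q\n", withSNI.TLSConfig.ServerName, noSNI.TLSConfig.ServerName) // "db.example.com" ""
}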
@ -773,7 +850,8 @@ func parsePort(s string) (uint16, error) {
} }
func makeDefaultDialer() *net.Dialer { func makeDefaultDialer() *net.Dialer {
return &net.Dialer{KeepAlive: 5 * time.Minute} // rely on Go's default KeepAlive settings
return &net.Dialer{}
} }
func makeDefaultResolver() *net.Resolver { func makeDefaultResolver() *net.Resolver {
@ -797,75 +875,75 @@ func makeConnectTimeoutDialFunc(timeout time.Duration) DialFunc {
return d.DialContext return d.DialContext
} }
// ValidateConnectTargetSessionAttrsReadWrite is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsReadWrite is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-write. // target_session_attrs=read-write.
func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error { func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
if result.Err != nil { if err != nil {
return result.Err return err
} }
if string(result.Rows[0][0]) == "on" { if string(result[0].Rows[0][0]) == "on" {
return errors.New("read only connection") return errors.New("read only connection")
} }
return nil return nil
} }
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsReadOnly is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=read-only. // target_session_attrs=read-only.
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error { func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read() result, err := pgConn.Exec(ctx, "show transaction_read_only").ReadAll()
if result.Err != nil { if err != nil {
return result.Err return err
} }
if string(result.Rows[0][0]) != "on" { if string(result[0].Rows[0][0]) != "on" {
return errors.New("connection is not read only") return errors.New("connection is not read only")
} }
return nil return nil
} }
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=standby. // target_session_attrs=standby.
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error { func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if result.Err != nil { if err != nil {
return result.Err return err
} }
if string(result.Rows[0][0]) != "t" { if string(result[0].Rows[0][0]) != "t" {
return errors.New("server is not in hot standby mode") return errors.New("server is not in hot standby mode")
} }
return nil return nil
} }
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsPrimary is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=primary. // target_session_attrs=primary.
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error { func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if result.Err != nil { if err != nil {
return result.Err return err
} }
if string(result.Rows[0][0]) == "t" { if string(result[0].Rows[0][0]) == "t" {
return errors.New("server is in standby mode") return errors.New("server is in standby mode")
} }
return nil return nil
} }
// ValidateConnectTargetSessionAttrsPreferStandby is an ValidateConnectFunc that implements libpq compatible // ValidateConnectTargetSessionAttrsPreferStandby is a ValidateConnectFunc that implements libpq compatible
// target_session_attrs=prefer-standby. // target_session_attrs=prefer-standby.
func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error { func ValidateConnectTargetSessionAttrsPreferStandby(ctx context.Context, pgConn *PgConn) error {
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read() result, err := pgConn.Exec(ctx, "select pg_is_in_recovery()").ReadAll()
if result.Err != nil { if err != nil {
return result.Err return err
} }
if string(result.Rows[0][0]) != "t" { if string(result[0].Rows[0][0]) != "t" {
return &NotPreferredError{err: errors.New("server is not in hot standby mode")} return &NotPreferredError{err: errors.New("server is not in hot standby mode")}
} }
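
These validators are meant to be assigned to Config.ValidateConnect, where a failure closes the connection and moves on to the next fallback. A sketch, assuming a hypothetical DATABASE_URL pointing at a multi-host cluster:

package main

import (
	"context"
	"log"
	"os"

	"github.com/jackc/pgx/v5/pgconn"
)

func main() {
	cfg, err := pgconn.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		log.Fatal(err)
	}
	// Equivalent to target_session_attrs=read-write: reject read-only servers.
	cfg.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
	conn, err := pgconn.ConnectConfig(context.Background(), cfg)
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close(context.Background())
}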


@ -4,10 +4,11 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"os/user" "os/user"
"path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -17,8 +18,25 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestParseConfig(t *testing.T) { func skipOnWindows(t *testing.T) {
t.Parallel() if runtime.GOOS == "windows" {
t.Skip("FIXME: skipping on Windows, investigate why this test fails in CI environment")
}
}
func getDefaultPort(t *testing.T) uint16 {
if envPGPORT := os.Getenv("PGPORT"); envPGPORT != "" {
p, err := strconv.ParseUint(envPGPORT, 10, 16)
require.NoError(t, err)
return uint16(p)
}
return 5432
}
func getDefaultUser(t *testing.T) string {
if pguser := os.Getenv("PGUSER"); pguser != "" {
return pguser
}
var osUserName string var osUserName string
osUser, err := user.Current() osUser, err := user.Current()
@ -32,10 +50,20 @@ func TestParseConfig(t *testing.T) {
} }
} }
return osUserName
}
func TestParseConfig(t *testing.T) {
skipOnWindows(t)
t.Parallel()
config, err := pgconn.ParseConfig("") config, err := pgconn.ParseConfig("")
require.NoError(t, err) require.NoError(t, err)
defaultHost := config.Host defaultHost := config.Host
defaultUser := getDefaultUser(t)
defaultPort := getDefaultPort(t)
tests := []struct { tests := []struct {
name string name string
connString string connString string
@ -53,10 +81,11 @@ func TestParseConfig(t *testing.T) {
Database: "mydb", Database: "mydb",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "localhost",
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
TLSConfig: nil, TLSConfig: nil,
@ -89,11 +118,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "localhost",
}, },
}, },
}, },
@ -111,10 +141,11 @@ func TestParseConfig(t *testing.T) {
Database: "mydb", Database: "mydb",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "localhost",
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
TLSConfig: nil, TLSConfig: nil,
@ -133,6 +164,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb", Database: "mydb",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "localhost",
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
@ -148,6 +180,7 @@ func TestParseConfig(t *testing.T) {
Database: "mydb", Database: "mydb",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "localhost",
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
@ -198,7 +231,7 @@ func TestParseConfig(t *testing.T) {
name: "database url missing user and password", name: "database url missing user and password",
connString: "postgres://localhost:5432/mydb?sslmode=disable", connString: "postgres://localhost:5432/mydb?sslmode=disable",
config: &pgconn.Config{ config: &pgconn.Config{
User: osUserName, User: defaultUser,
Host: "localhost", Host: "localhost",
Port: 5432, Port: 5432,
Database: "mydb", Database: "mydb",
@ -223,9 +256,9 @@ func TestParseConfig(t *testing.T) {
name: "database url unix domain socket host", name: "database url unix domain socket host",
connString: "postgres:///foo?host=/tmp", connString: "postgres:///foo?host=/tmp",
config: &pgconn.Config{ config: &pgconn.Config{
User: osUserName, User: defaultUser,
Host: "/tmp", Host: "/tmp",
Port: 5432, Port: defaultPort,
Database: "foo", Database: "foo",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
@ -235,9 +268,9 @@ func TestParseConfig(t *testing.T) {
name: "database url unix domain socket host on windows", name: "database url unix domain socket host on windows",
connString: "postgres:///foo?host=C:\\tmp", connString: "postgres:///foo?host=C:\\tmp",
config: &pgconn.Config{ config: &pgconn.Config{
User: osUserName, User: defaultUser,
Host: "C:\\tmp", Host: "C:\\tmp",
Port: 5432, Port: defaultPort,
Database: "foo", Database: "foo",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
@ -247,9 +280,9 @@ func TestParseConfig(t *testing.T) {
name: "database url dbname", name: "database url dbname",
connString: "postgres://localhost/?dbname=foo&sslmode=disable", connString: "postgres://localhost/?dbname=foo&sslmode=disable",
config: &pgconn.Config{ config: &pgconn.Config{
User: osUserName, User: defaultUser,
Host: "localhost", Host: "localhost",
Port: 5432, Port: defaultPort,
Database: "foo", Database: "foo",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
@ -297,14 +330,14 @@ func TestParseConfig(t *testing.T) {
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Host: "2001:db8::1", Host: "2001:db8::1",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{ {
name: "DSN everything", name: "Key/value everything",
connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5", connString: "user=jack password=secret host=localhost port=5432 dbname=mydb sslmode=disable application_name=pgxtest search_path=myschema connect_timeout=5",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
@ -321,7 +354,7 @@ func TestParseConfig(t *testing.T) {
}, },
}, },
{ {
name: "DSN with escaped single quote", name: "Key/value with escaped single quote",
connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable", connString: "user=jack\\'s password=secret host=localhost port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack's", User: "jack's",
@ -334,7 +367,7 @@ func TestParseConfig(t *testing.T) {
}, },
}, },
{ {
name: "DSN with escaped backslash", name: "Key/value with escaped backslash",
connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable", connString: "user=jack password=sooper\\\\secret host=localhost port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
@ -347,48 +380,48 @@ func TestParseConfig(t *testing.T) {
}, },
}, },
{ {
name: "DSN with single quoted values", name: "Key/value with single quoted values",
connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'", connString: "user='jack' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Host: "localhost", Host: "localhost",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{ {
name: "DSN with single quoted value with escaped single quote", name: "Key/value with single quoted value with escaped single quote",
connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'", connString: "user='jack\\'s' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack's", User: "jack's",
Host: "localhost", Host: "localhost",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{ {
name: "DSN with empty single quoted value", name: "Key/value with empty single quoted value",
connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'", connString: "user='jack' password='' host='localhost' dbname='mydb' sslmode='disable'",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Host: "localhost", Host: "localhost",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{ {
name: "DSN with space between key and value", name: "Key/value with space between key and value",
connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'", connString: "user = 'jack' password = '' host = 'localhost' dbname = 'mydb' sslmode='disable'",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Host: "localhost", Host: "localhost",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
@ -401,19 +434,19 @@ func TestParseConfig(t *testing.T) {
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "foo", Host: "foo",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "bar", Host: "bar",
Port: 5432, Port: defaultPort,
TLSConfig: nil, TLSConfig: nil,
}, },
&pgconn.FallbackConfig{ {
Host: "baz", Host: "baz",
Port: 5432, Port: defaultPort,
TLSConfig: nil, TLSConfig: nil,
}, },
}, },
@ -431,12 +464,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "bar", Host: "bar",
Port: 2, Port: 2,
TLSConfig: nil, TLSConfig: nil,
}, },
&pgconn.FallbackConfig{ {
Host: "baz", Host: "baz",
Port: 3, Port: 3,
TLSConfig: nil, TLSConfig: nil,
@ -459,7 +492,7 @@ func TestParseConfig(t *testing.T) {
}, },
}, },
{ {
name: "DSN multiple hosts one port", name: "Key/value multiple hosts one port",
connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable", connString: "user=jack password=secret host=foo,bar,baz port=5432 dbname=mydb sslmode=disable",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
@ -470,12 +503,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "bar", Host: "bar",
Port: 5432, Port: 5432,
TLSConfig: nil, TLSConfig: nil,
}, },
&pgconn.FallbackConfig{ {
Host: "baz", Host: "baz",
Port: 5432, Port: 5432,
TLSConfig: nil, TLSConfig: nil,
@ -484,7 +517,7 @@ func TestParseConfig(t *testing.T) {
}, },
}, },
{ {
name: "DSN multiple hosts multiple ports", name: "Key/value multiple hosts multiple ports",
connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable", connString: "user=jack password=secret host=foo,bar,baz port=1,2,3 dbname=mydb sslmode=disable",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
@ -495,12 +528,12 @@ func TestParseConfig(t *testing.T) {
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "bar", Host: "bar",
Port: 2, Port: 2,
TLSConfig: nil, TLSConfig: nil,
}, },
&pgconn.FallbackConfig{ {
Host: "baz", Host: "baz",
Port: 3, Port: 3,
TLSConfig: nil, TLSConfig: nil,
@ -509,44 +542,47 @@ func TestParseConfig(t *testing.T) {
}, },
}, },
{ {
name: "multiple hosts and fallback tsl", name: "multiple hosts and fallback tls",
connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer", connString: "user=jack password=secret host=foo,bar,baz dbname=mydb sslmode=prefer",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "foo", Host: "foo",
Port: 5432, Port: defaultPort,
Database: "mydb", Database: "mydb",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "foo",
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "foo", Host: "foo",
Port: 5432, Port: defaultPort,
TLSConfig: nil, TLSConfig: nil,
}, },
&pgconn.FallbackConfig{ {
Host: "bar", Host: "bar",
Port: 5432, Port: defaultPort,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "bar",
}}, }},
&pgconn.FallbackConfig{ {
Host: "bar", Host: "bar",
Port: 5432, Port: defaultPort,
TLSConfig: nil, TLSConfig: nil,
}, },
&pgconn.FallbackConfig{ {
Host: "baz", Host: "baz",
Port: 5432, Port: defaultPort,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "baz",
}}, }},
&pgconn.FallbackConfig{ {
Host: "baz", Host: "baz",
Port: 5432, Port: defaultPort,
TLSConfig: nil, TLSConfig: nil,
}, },
}, },
@ -648,6 +684,82 @@ func TestParseConfig(t *testing.T) {
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
}, },
}, },
{
name: "SNI is set by default",
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
ServerName: "sni.test",
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set for IPv4",
connString: "postgres://jack:secret@1.1.1.1:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "1.1.1.1",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set for IPv6",
connString: "postgres://jack:secret@[::1]:5432/mydb?sslmode=require",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "::1",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set when disabled (URL-style)",
connString: "postgres://jack:secret@sni.test:5432/mydb?sslmode=require&sslsni=0",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: 5432,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
{
name: "SNI is not set when disabled (key/value style)",
connString: "user=jack password=secret host=sni.test dbname=mydb sslmode=require sslsni=0",
config: &pgconn.Config{
User: "jack",
Password: "secret",
Host: "sni.test",
Port: defaultPort,
Database: "mydb",
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
},
},
} }
for i, tt := range tests { for i, tt := range tests {
@ -661,18 +773,18 @@ func TestParseConfig(t *testing.T) {
} }
// https://github.com/jackc/pgconn/issues/47 // https://github.com/jackc/pgconn/issues/47
func TestParseConfigDSNWithTrailingEmptyEqualDoesNotPanic(t *testing.T) { func TestParseConfigKVWithTrailingEmptyEqualDoesNotPanic(t *testing.T) {
_, err := pgconn.ParseConfig("host= user= password= port= database=") _, err := pgconn.ParseConfig("host= user= password= port= database=")
require.NoError(t, err) require.NoError(t, err)
} }
func TestParseConfigDSNLeadingEqual(t *testing.T) { func TestParseConfigKVLeadingEqual(t *testing.T) {
_, err := pgconn.ParseConfig("= user=jack") _, err := pgconn.ParseConfig("= user=jack")
require.Error(t, err) require.Error(t, err)
} }
// https://github.com/jackc/pgconn/issues/49 // https://github.com/jackc/pgconn/issues/49
func TestParseConfigDSNTrailingBackslash(t *testing.T) { func TestParseConfigKVTrailingBackslash(t *testing.T) {
_, err := pgconn.ParseConfig(`x=x\`) _, err := pgconn.ParseConfig(`x=x\`)
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "invalid backslash") assert.Contains(t, err.Error(), "invalid backslash")
@ -705,7 +817,7 @@ func TestConfigCopyOriginalConfigDidNotChange(t *testing.T) {
} }
func TestConfigCopyCanBeUsedToConnect(t *testing.T) { func TestConfigCopyCanBeUsedToConnect(t *testing.T) {
connString := os.Getenv("PGX_TEST_CONN_STRING") connString := os.Getenv("PGX_TEST_DATABASE")
original, err := pgconn.ParseConfig(connString) original, err := pgconn.ParseConfig(connString)
require.NoError(t, err) require.NoError(t, err)
@ -820,20 +932,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
} }
} }
pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT"} pgEnvvars := []string{"PGHOST", "PGPORT", "PGDATABASE", "PGUSER", "PGPASSWORD", "PGAPPNAME", "PGSSLMODE", "PGCONNECT_TIMEOUT", "PGSSLSNI", "PGTZ", "PGOPTIONS"}
savedEnv := make(map[string]string)
for _, n := range pgEnvvars {
savedEnv[n] = os.Getenv(n)
}
defer func() {
for k, v := range savedEnv {
err := os.Setenv(k, v)
if err != nil {
t.Fatalf("Unable to restore environment: %v", err)
}
}
}()
tests := []struct { tests := []struct {
name string name string
@ -853,7 +952,7 @@ func TestParseConfigEnvLibpq(t *testing.T) {
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "123.123.123.123", Host: "123.123.123.123",
Port: 5432, Port: 5432,
TLSConfig: nil, TLSConfig: nil,
@ -872,6 +971,8 @@ func TestParseConfigEnvLibpq(t *testing.T) {
"PGCONNECT_TIMEOUT": "10", "PGCONNECT_TIMEOUT": "10",
"PGSSLMODE": "disable", "PGSSLMODE": "disable",
"PGAPPNAME": "pgxtest", "PGAPPNAME": "pgxtest",
"PGTZ": "America/New_York",
"PGOPTIONS": "-c search_path=myschema",
}, },
config: &pgconn.Config{ config: &pgconn.Config{
Host: "123.123.123.123", Host: "123.123.123.123",
@ -881,20 +982,31 @@ func TestParseConfigEnvLibpq(t *testing.T) {
Password: "baz", Password: "baz",
ConnectTimeout: 10 * time.Second, ConnectTimeout: 10 * time.Second,
TLSConfig: nil, TLSConfig: nil,
RuntimeParams: map[string]string{"application_name": "pgxtest"}, RuntimeParams: map[string]string{"application_name": "pgxtest", "timezone": "America/New_York", "options": "-c search_path=myschema"},
},
},
{
name: "SNI can be disabled via environment variable",
envvars: map[string]string{
"PGHOST": "test.foo",
"PGSSLMODE": "require",
"PGSSLSNI": "0",
},
config: &pgconn.Config{
User: osUserName,
Host: "test.foo",
Port: 5432,
TLSConfig: &tls.Config{
InsecureSkipVerify: true,
},
RuntimeParams: map[string]string{},
}, },
}, },
} }
for i, tt := range tests { for i, tt := range tests {
for _, n := range pgEnvvars { for _, env := range pgEnvvars {
err := os.Unsetenv(n) t.Setenv(env, tt.envvars[env])
require.NoError(t, err)
}
for k, v := range tt.envvars {
err := os.Setenv(k, v)
require.NoError(t, err)
} }
config, err := pgconn.ParseConfig("") config, err := pgconn.ParseConfig("")
@ -907,18 +1019,14 @@ func TestParseConfigEnvLibpq(t *testing.T) {
} }
func TestParseConfigReadsPgPassfile(t *testing.T) { func TestParseConfigReadsPgPassfile(t *testing.T) {
skipOnWindows(t)
t.Parallel() t.Parallel()
tf, err := ioutil.TempFile("", "") tfName := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0600)
require.NoError(t, err) require.NoError(t, err)
defer tf.Close() connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName)
defer os.Remove(tf.Name())
_, err = tf.Write([]byte("test1:5432:curlydb:curly:nyuknyuknyuk"))
require.NoError(t, err)
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tf.Name())
expected := &pgconn.Config{ expected := &pgconn.Config{
User: "curly", User: "curly",
Password: "nyuknyuknyuk", Password: "nyuknyuknyuk",
@ -936,15 +1044,12 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
} }
func TestParseConfigReadsPgServiceFile(t *testing.T) { func TestParseConfigReadsPgServiceFile(t *testing.T) {
skipOnWindows(t)
t.Parallel() t.Parallel()
tf, err := ioutil.TempFile("", "") tfName := filepath.Join(t.TempDir(), "config")
require.NoError(t, err)
defer tf.Close() err := os.WriteFile(tfName, []byte(`
defer os.Remove(tf.Name())
_, err = tf.Write([]byte(`
[abc] [abc]
host=abc.example.com host=abc.example.com
port=9999 port=9999
@ -956,9 +1061,11 @@ host = def.example.com
dbname = defdb dbname = defdb
user = defuser user = defuser
application_name = spaced string application_name = spaced string
`)) `), 0600)
require.NoError(t, err) require.NoError(t, err)
defaultPort := getDefaultPort(t)
tests := []struct { tests := []struct {
name string name string
connString string connString string
@ -966,7 +1073,7 @@ application_name = spaced string
}{ }{
{ {
name: "abc", name: "abc",
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "abc"), connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "abc"),
config: &pgconn.Config{ config: &pgconn.Config{
Host: "abc.example.com", Host: "abc.example.com",
Database: "abcdb", Database: "abcdb",
@ -974,10 +1081,11 @@ application_name = spaced string
Port: 9999, Port: 9999,
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "abc.example.com",
}, },
RuntimeParams: map[string]string{}, RuntimeParams: map[string]string{},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "abc.example.com", Host: "abc.example.com",
Port: 9999, Port: 9999,
TLSConfig: nil, TLSConfig: nil,
@ -987,20 +1095,21 @@ application_name = spaced string
}, },
{ {
name: "def", name: "def",
connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tf.Name(), "def"), connString: fmt.Sprintf("postgres:///?servicefile=%s&service=%s", tfName, "def"),
config: &pgconn.Config{ config: &pgconn.Config{
Host: "def.example.com", Host: "def.example.com",
Port: 5432, Port: defaultPort,
Database: "defdb", Database: "defdb",
User: "defuser", User: "defuser",
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "def.example.com",
}, },
RuntimeParams: map[string]string{"application_name": "spaced string"}, RuntimeParams: map[string]string{"application_name": "spaced string"},
Fallbacks: []*pgconn.FallbackConfig{ Fallbacks: []*pgconn.FallbackConfig{
&pgconn.FallbackConfig{ {
Host: "def.example.com", Host: "def.example.com",
Port: 5432, Port: defaultPort,
TLSConfig: nil, TLSConfig: nil,
}, },
}, },
@ -1008,7 +1117,7 @@ application_name = spaced string
}, },
{ {
name: "conn string has precedence", name: "conn string has precedence",
connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tf.Name(), "abc"), connString: fmt.Sprintf("postgres://other.example.com:7777/?servicefile=%s&service=%s&sslmode=disable", tfName, "abc"),
config: &pgconn.Config{ config: &pgconn.Config{
Host: "other.example.com", Host: "other.example.com",
Database: "abcdb", Database: "abcdb",


@ -8,8 +8,7 @@ import (
// ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a // ContextWatcher watches a context and performs an action when the context is canceled. It can watch one context at a
// time. // time.
type ContextWatcher struct { type ContextWatcher struct {
onCancel func() handler Handler
onUnwatchAfterCancel func()
unwatchChan chan struct{} unwatchChan chan struct{}
lock sync.Mutex lock sync.Mutex
@ -20,10 +19,9 @@ type ContextWatcher struct {
// NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled. // NewContextWatcher returns a ContextWatcher. onCancel will be called when a watched context is canceled.
// OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and // OnUnwatchAfterCancel will be called when Unwatch is called and the watched context had already been canceled and
// onCancel called. // onCancel called.
func NewContextWatcher(onCancel func(), onUnwatchAfterCancel func()) *ContextWatcher { func NewContextWatcher(handler Handler) *ContextWatcher {
cw := &ContextWatcher{ cw := &ContextWatcher{
onCancel: onCancel, handler: handler,
onUnwatchAfterCancel: onUnwatchAfterCancel,
unwatchChan: make(chan struct{}), unwatchChan: make(chan struct{}),
} }
@ -46,7 +44,7 @@ func (cw *ContextWatcher) Watch(ctx context.Context) {
go func() { go func() {
select { select {
case <-ctx.Done(): case <-ctx.Done():
cw.onCancel() cw.handler.HandleCancel(ctx)
cw.onCancelWasCalled = true cw.onCancelWasCalled = true
<-cw.unwatchChan <-cw.unwatchChan
case <-cw.unwatchChan: case <-cw.unwatchChan:
@ -66,8 +64,17 @@ func (cw *ContextWatcher) Unwatch() {
if cw.watchInProgress { if cw.watchInProgress {
cw.unwatchChan <- struct{}{} cw.unwatchChan <- struct{}{}
if cw.onCancelWasCalled { if cw.onCancelWasCalled {
cw.onUnwatchAfterCancel() cw.handler.HandleUnwatchAfterCancel()
} }
cw.watchInProgress = false cw.watchInProgress = false
} }
} }
type Handler interface {
// HandleCancel is called when the context that a ContextWatcher is currently watching is canceled. canceledCtx is the
// context that was canceled.
HandleCancel(canceledCtx context.Context)
// HandleUnwatchAfterCancel is called when a ContextWatcher that called HandleCancel on this Handler is unwatched.
HandleUnwatchAfterCancel()
}
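
A minimal Handler sketch in the spirit of the DeadlineContextWatcherHandler default seen earlier; the type and its wrapped net.Conn are hypothetical, and the handler simply forces pending I/O to fail fast on cancellation:

package example

import (
	"context"
	"net"
	"time"
)

// deadlineHandler is a hypothetical ctxwatch.Handler implementation. It
// unblocks pending reads and writes by moving the connection deadline into
// the past when the watched context is canceled.
type deadlineHandler struct {
	conn net.Conn
}

func (h *deadlineHandler) HandleCancel(_ context.Context) {
	h.conn.SetDeadline(time.Now())
}

func (h *deadlineHandler) HandleUnwatchAfterCancel() {
	// Clear the deadline so the connection is usable again.
	h.conn.SetDeadline(time.Time{})
}

Such a handler would be passed to ctxwatch.NewContextWatcher and driven by Watch/Unwatch around each blocking operation, as the tests below do.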


@ -6,17 +6,32 @@ import (
"testing" "testing"
"time" "time"
"github.com/jackc/pgx/v5/pgconn/internal/ctxwatch" "github.com/jackc/pgx/v5/pgconn/ctxwatch"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type testHandler struct {
handleCancel func(context.Context)
+	handleUnwatchAfterCancel func()
+}
+
+func (h *testHandler) HandleCancel(ctx context.Context) {
+	h.handleCancel(ctx)
+}
+
+func (h *testHandler) HandleUnwatchAfterCancel() {
+	h.handleUnwatchAfterCancel()
+}

func TestContextWatcherContextCancelled(t *testing.T) {
	canceledChan := make(chan struct{})
	cleanupCalled := false
-	cw := ctxwatch.NewContextWatcher(func() {
+	cw := ctxwatch.NewContextWatcher(&testHandler{
+		handleCancel: func(context.Context) {
			canceledChan <- struct{}{}
-	}, func() {
+		}, handleUnwatchAfterCancel: func() {
			cleanupCalled = true
+		},
	})

	ctx, cancel := context.WithCancel(context.Background())
@@ -34,11 +49,13 @@ func TestContextWatcherContextCancelled(t *testing.T) {
	require.True(t, cleanupCalled, "Cleanup func was not called")
}

-func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
-	cw := ctxwatch.NewContextWatcher(func() {
+func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
+	cw := ctxwatch.NewContextWatcher(&testHandler{
+		handleCancel: func(context.Context) {
			t.Error("cancel func should not have been called")
-	}, func() {
+		}, handleUnwatchAfterCancel: func() {
			t.Error("cleanup func should not have been called")
+		},
	})

	ctx, cancel := context.WithCancel(context.Background())
@@ -48,11 +65,12 @@ func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
}

func TestContextWatcherMultipleWatchPanics(t *testing.T) {
-	cw := ctxwatch.NewContextWatcher(func() {}, func() {})
+	cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	cw.Watch(ctx)
+	defer cw.Unwatch()

	ctx2, cancel2 := context.WithCancel(context.Background())
	defer cancel2()
@@ -60,7 +78,7 @@ func TestContextWatcherMultipleWatchPanics(t *testing.T) {
}

func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
-	cw := ctxwatch.NewContextWatcher(func() {}, func() {})
+	cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})
	cw.Unwatch() // unwatch when not / never watching

	ctx, cancel := context.WithCancel(context.Background())
@@ -71,7 +89,7 @@ func TestContextWatcherUnwatchWhenNotWatchingIsSafe(t *testing.T) {
}

func TestContextWatcherUnwatchIsConcurrencySafe(t *testing.T) {
-	cw := ctxwatch.NewContextWatcher(func() {}, func() {})
+	cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
@@ -87,10 +105,12 @@ func TestContextWatcherStress(t *testing.T) {
	var cancelFuncCalls int64
	var cleanupFuncCalls int64

-	cw := ctxwatch.NewContextWatcher(func() {
+	cw := ctxwatch.NewContextWatcher(&testHandler{
+		handleCancel: func(context.Context) {
			atomic.AddInt64(&cancelFuncCalls, 1)
-	}, func() {
+		}, handleUnwatchAfterCancel: func() {
			atomic.AddInt64(&cleanupFuncCalls, 1)
+		},
	})

	cycleCount := 100000
@@ -103,7 +123,9 @@ func TestContextWatcherStress(t *testing.T) {
	}

	// Without time.Sleep, cw.Unwatch will almost always run before the cancel func which means cancel will never happen. This gives us a better mix.
-		if i%3 == 0 {
+		if i%333 == 0 {
+			// On Windows, Sleep takes more time than expected, so sleep here less
+			// frequently to keep CI run times reasonable.
			time.Sleep(time.Nanosecond)
		}
@@ -131,7 +153,7 @@ func TestContextWatcherStress(t *testing.T) {
}

func BenchmarkContextWatcherUncancellable(b *testing.B) {
-	cw := ctxwatch.NewContextWatcher(func() {}, func() {})
+	cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

	for i := 0; i < b.N; i++ {
		cw.Watch(context.Background())
@@ -140,7 +162,7 @@ func BenchmarkContextWatcherUncancellable(b *testing.B) {
}

func BenchmarkContextWatcherCancelled(b *testing.B) {
-	cw := ctxwatch.NewContextWatcher(func() {}, func() {})
+	cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

	for i := 0; i < b.N; i++ {
		ctx, cancel := context.WithCancel(context.Background())
@@ -151,7 +173,7 @@ func BenchmarkContextWatcherCancelled(b *testing.B) {
}

func BenchmarkContextWatcherCancellable(b *testing.B) {
-	cw := ctxwatch.NewContextWatcher(func() {}, func() {})
+	cw := ctxwatch.NewContextWatcher(&testHandler{handleCancel: func(context.Context) {}, handleUnwatchAfterCancel: func() {}})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()


@@ -5,8 +5,8 @@ nearly the same level is the C library libpq.

Establishing a Connection

-Use Connect to establish a connection. It accepts a connection string in URL or DSN and will read the environment for
-libpq style environment variables.
+Use Connect to establish a connection. It accepts a connection string in URL or keyword/value format and will read the
+environment for libpq style environment variables.

Executing a Query

@@ -20,13 +20,17 @@ result. The ReadAll method reads all query results into memory.

Pipeline Mode

-Pipeline mode allows sending queries without having read the results of previously sent queries. It allows
-control of exactly how many and when network round trips occur.
+Pipeline mode allows sending queries without having read the results of previously sent queries. It allows control of
+exactly how many and when network round trips occur.

Context Support
-All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
-method immediately returns. In most circumstances, this will close the underlying connection.
+All potentially blocking operations take a context.Context. The default behavior when a context is canceled is for the
+method to immediately return. In most circumstances, this will also close the underlying connection. This behavior can
+be customized by using BuildContextWatcherHandler on the Config to create a ctxwatch.Handler with different behavior.
+This can be especially useful when queries are frequently canceled and the overhead of creating new connections is a
+problem. DeadlineContextWatcherHandler and CancelRequestContextWatcherHandler can be used to introduce a delay before
+interrupting the query in such a way as to close the connection.

The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
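For example, a caller that prefers to keep the connection alive when a context is canceled can install a DeadlineContextWatcherHandler through the new hook; a minimal sketch (the connection string and one-second delay are illustrative, error handling elided):

	config, err := pgconn.ParseConfig("host=localhost user=postgres")
	if err != nil {
		log.Fatal(err)
	}
	// Give a canceled query one second to finish on its own before a deadline
	// is set on the underlying net.Conn.
	config.BuildContextWatcherHandler = func(pgConn *pgconn.PgConn) ctxwatch.Handler {
		return &pgconn.DeadlineContextWatcherHandler{
			Conn:          pgConn.Conn(),
			DeadlineDelay: time.Second,
		}
	}
	conn, err := pgconn.ConnectConfig(context.Background(), config)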


@@ -12,14 +12,15 @@ import (

// SafeToRetry checks if the err is guaranteed to have occurred before sending any data to the server.
func SafeToRetry(err error) bool {
-	if e, ok := err.(interface{ SafeToRetry() bool }); ok {
-		return e.SafeToRetry()
+	var retryableErr interface{ SafeToRetry() bool }
+	if errors.As(err, &retryableErr) {
+		return retryableErr.SafeToRetry()
	}
	return false
}
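Because errors.As unwraps, SafeToRetry now also reports true when the retryable error is wrapped (for example inside the perDialConnectError introduced below). A retry loop built on it might look like this sketch (query and retry policy are illustrative):

	for attempt := 0; attempt < 3; attempt++ {
		_, err := pgConn.Exec(ctx, "select 1").ReadAll()
		if err == nil || !pgconn.SafeToRetry(err) {
			break // success, or data may already have been sent
		}
	}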
-// Timeout checks if err was was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
-// context.Canceled, context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
+// Timeout checks if err was caused by a timeout. To be specific, it is true if err was caused within pgconn by a
+// context.DeadlineExceeded or an implementer of net.Error where Timeout() is true.
func Timeout(err error) bool {
	var timeoutErr *errTimeout
	return errors.As(err, &timeoutErr)
@@ -30,6 +31,7 @@ func Timeout(err error) bool {
// detailed field description.
type PgError struct {
	Severity            string
+	SeverityUnlocalized string
	Code                string
	Message             string
	Detail              string
@@ -57,22 +59,37 @@ func (pe *PgError) SQLState() string {
	return pe.Code
}
-type connectError struct {
-	config *Config
-	msg    string
-	err    error
+// ConnectError is the error returned when a connection attempt fails.
+type ConnectError struct {
+	Config *Config // The configuration that was used in the connection attempt.
+	err    error
}

-func (e *connectError) Error() string {
-	sb := &strings.Builder{}
-	fmt.Fprintf(sb, "failed to connect to `host=%s user=%s database=%s`: %s", e.config.Host, e.config.User, e.config.Database, e.msg)
-	if e.err != nil {
-		fmt.Fprintf(sb, " (%s)", e.err.Error())
-	}
-	return sb.String()
+func (e *ConnectError) Error() string {
+	prefix := fmt.Sprintf("failed to connect to `user=%s database=%s`:", e.Config.User, e.Config.Database)
+	details := e.err.Error()
+	if strings.Contains(details, "\n") {
+		return prefix + "\n\t" + strings.ReplaceAll(details, "\n", "\n\t")
+	} else {
+		return prefix + " " + details
+	}
}

-func (e *connectError) Unwrap() error {
+func (e *ConnectError) Unwrap() error {
	return e.err
}

+type perDialConnectError struct {
+	address          string
+	originalHostname string
+	err              error
+}
+
+func (e *perDialConnectError) Error() string {
+	return fmt.Sprintf("%s (%s): %s", e.address, e.originalHostname, e.err.Error())
+}
+
+func (e *perDialConnectError) Unwrap() error {
+	return e.err
+}
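With ConnectError exported, callers can inspect a failed attempt's configuration; a sketch (the connection string is a placeholder):

	conn, err := pgconn.Connect(context.Background(), "postgres://app@db.example.com/app")
	if err != nil {
		var connErr *pgconn.ConnectError
		if errors.As(err, &connErr) {
			// Error() already names user and database; Config exposes the rest.
			log.Printf("connect to host %q failed: %v", connErr.Config.Host, err)
		}
	}
	_ = conn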
@@ -88,29 +105,47 @@ func (e *connLockError) Error() string {
	return e.status
}

-type parseConfigError struct {
-	connString string
-	msg        string
-	err        error
+// ParseConfigError is the error returned when a connection string cannot be parsed.
+type ParseConfigError struct {
+	ConnString string // The connection string that could not be parsed.
+	msg        string
+	err        error
}

-func (e *parseConfigError) Error() string {
-	connString := redactPW(e.connString)
+func NewParseConfigError(conn, msg string, err error) error {
+	return &ParseConfigError{
+		ConnString: conn,
+		msg:        msg,
+		err:        err,
+	}
+}
+
+func (e *ParseConfigError) Error() string {
+	// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better to
+	// only return a static string. That would ensure that the error message cannot leak a password. The ConnString
+	// field would allow access to the original string if desired and Unwrap would allow access to the underlying error.
+	connString := redactPW(e.ConnString)
	if e.err == nil {
		return fmt.Sprintf("cannot parse `%s`: %s", connString, e.msg)
	}
	return fmt.Sprintf("cannot parse `%s`: %s (%s)", connString, e.msg, e.err.Error())
}

-func (e *parseConfigError) Unwrap() error {
+func (e *ParseConfigError) Unwrap() error {
	return e.err
}
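Likewise, ParseConfigError is now public; a sketch of detecting it (the environment variable name is illustrative):

	config, err := pgconn.ParseConfig(os.Getenv("DATABASE_URL"))
	if err != nil {
		var parseErr *pgconn.ParseConfigError
		if errors.As(err, &parseErr) {
			// The Error() text redacts passwords; ConnString preserves the
			// original input when the caller needs it.
			log.Fatalf("invalid connection string: %v", parseErr)
		}
		log.Fatal(err)
	}
	_ = config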
-// preferContextOverNetTimeoutError returns ctx.Err() if ctx.Err() is present and err is a net.Error with Timeout() ==
-// true. Otherwise returns err.
-func preferContextOverNetTimeoutError(ctx context.Context, err error) error {
-	if err, ok := err.(net.Error); ok && err.Timeout() && ctx.Err() != nil {
-		return &errTimeout{err: ctx.Err()}
-	}
-	return err
-}
+func normalizeTimeoutError(ctx context.Context, err error) error {
+	var netErr net.Error
+	if errors.As(err, &netErr) && netErr.Timeout() {
+		if ctx.Err() == context.Canceled {
+			// Since the timeout was caused by a context cancellation, the actual error is context.Canceled not the timeout error.
+			return context.Canceled
+		} else if ctx.Err() == context.DeadlineExceeded {
+			return &errTimeout{err: ctx.Err()}
+		} else {
+			return &errTimeout{err: netErr}
+		}
+	}
+	return err
+}
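This normalization lets callers distinguish an explicit cancellation from an expired deadline; a sketch (the query is illustrative):

	result := pgConn.ExecParams(ctx, "select pg_sleep(10)", nil, nil, nil, nil)
	_, err := result.Close()
	switch {
	case errors.Is(err, context.Canceled):
		// canceled explicitly; the error is context.Canceled itself
	case pgconn.Timeout(err):
		// the context deadline expired, or a net.Error timed out
	}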
@@ -184,10 +219,10 @@ func redactPW(connString string) string {
			return redactURL(u)
		}
	}
-	quotedDSN := regexp.MustCompile(`password='[^']*'`)
-	connString = quotedDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
-	plainDSN := regexp.MustCompile(`password=[^ ]*`)
-	connString = plainDSN.ReplaceAllLiteralString(connString, "password=xxxxx")
+	quotedKV := regexp.MustCompile(`password='[^']*'`)
+	connString = quotedKV.ReplaceAllLiteralString(connString, "password=xxxxx")
+	plainKV := regexp.MustCompile(`password=[^ ]*`)
+	connString = plainKV.ReplaceAllLiteralString(connString, "password=xxxxx")
	brokenURL := regexp.MustCompile(`:[^:@]+?@`)
	connString = brokenURL.ReplaceAllLiteralString(connString, ":xxxxxx@")
	return connString


@@ -19,18 +19,18 @@ func TestConfigError(t *testing.T) {
			expectedMsg: "cannot parse `postgresql://foo:xxxxx@host`: msg",
		},
		{
-			name:        "dsn with password unquoted",
+			name:        "keyword/value with password unquoted",
			err:         pgconn.NewParseConfigError("host=host password=password user=user", "msg", nil),
			expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
		},
		{
-			name:        "dsn with password quoted",
+			name:        "keyword/value with password quoted",
			err:         pgconn.NewParseConfigError("host=host password='pass word' user=user", "msg", nil),
			expectedMsg: "cannot parse `host=host password=xxxxx user=user`: msg",
		},
		{
			name:        "weird url",
-			err:         pgconn.NewParseConfigError("postgresql://foo::pasword@host:1:", "msg", nil),
+			err:         pgconn.NewParseConfigError("postgresql://foo::password@host:1:", "msg", nil),
			expectedMsg: "cannot parse `postgresql://foo:xxxxx@host:1:`: msg",
		},
		{


@@ -1,11 +1,3 @@
// File export_test exports some methods for better testing.

package pgconn
-
-func NewParseConfigError(conn, msg string, err error) error {
-	return &parseConfigError{
-		connString: conn,
-		msg:        msg,
-		err:        err,
-	}
-}


@@ -12,19 +12,19 @@ import (
)

func closeConn(t testing.TB, conn *pgconn.PgConn) {
-	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	require.NoError(t, conn.Close(ctx))
	select {
	case <-conn.CleanupDone():
-	case <-time.After(5 * time.Second):
+	case <-time.After(30 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}

// Do a simple query to ensure the connection is still usable
func ensureConnValid(t *testing.T, pgConn *pgconn.PgConn) {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	result := pgConn.ExecParams(ctx, "select generate_series(1,$1)", [][]byte{[]byte("3")}, nil, nil, nil).Read()
	cancel()


@@ -0,0 +1,139 @@
// Package bgreader provides an io.Reader that can optionally buffer reads in the background.
package bgreader
import (
"io"
"sync"
"github.com/jackc/pgx/v5/internal/iobufpool"
)
const (
StatusStopped = iota
StatusRunning
StatusStopping
)
// BGReader is an io.Reader that can optionally buffer reads in the background. It is safe for concurrent use.
type BGReader struct {
r io.Reader
cond *sync.Cond
status int32
readResults []readResult
}
type readResult struct {
buf *[]byte
err error
}
// Start starts the background reader. If the background reader is already running this is a no-op. The background
// reader will stop automatically when the underlying reader returns an error.
func (r *BGReader) Start() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
r.status = StatusRunning
go r.bgRead()
case StatusRunning:
// no-op
case StatusStopping:
r.status = StatusRunning
}
}
// Stop tells the background reader to stop after the in-progress Read returns. It is safe to call Stop when the
// background reader is not running.
func (r *BGReader) Stop() {
r.cond.L.Lock()
defer r.cond.L.Unlock()
switch r.status {
case StatusStopped:
// no-op
case StatusRunning:
r.status = StatusStopping
case StatusStopping:
// no-op
}
}
// Status returns the current status of the background reader.
func (r *BGReader) Status() int32 {
r.cond.L.Lock()
defer r.cond.L.Unlock()
return r.status
}
func (r *BGReader) bgRead() {
keepReading := true
for keepReading {
buf := iobufpool.Get(8192)
n, err := r.r.Read(*buf)
*buf = (*buf)[:n]
r.cond.L.Lock()
r.readResults = append(r.readResults, readResult{buf: buf, err: err})
if r.status == StatusStopping || err != nil {
r.status = StatusStopped
keepReading = false
}
r.cond.L.Unlock()
r.cond.Broadcast()
}
}
// Read implements the io.Reader interface.
func (r *BGReader) Read(p []byte) (int, error) {
r.cond.L.Lock()
defer r.cond.L.Unlock()
if len(r.readResults) > 0 {
return r.readFromReadResults(p)
}
// There are no unread background read results and the background reader is stopped.
if r.status == StatusStopped {
return r.r.Read(p)
}
// Wait for results from the background reader
for len(r.readResults) == 0 {
r.cond.Wait()
}
return r.readFromReadResults(p)
}
// readFromReadResults reads a result previously read by the background reader. r.cond.L must be held.
func (r *BGReader) readFromReadResults(p []byte) (int, error) {
buf := r.readResults[0].buf
var err error
n := copy(p, *buf)
if n == len(*buf) {
err = r.readResults[0].err
iobufpool.Put(buf)
if len(r.readResults) == 1 {
r.readResults = nil
} else {
r.readResults = r.readResults[1:]
}
} else {
*buf = (*buf)[n:]
r.readResults[0].buf = buf
}
return n, err
}
func New(r io.Reader) *BGReader {
return &BGReader{
r: r,
cond: &sync.Cond{
L: &sync.Mutex{},
},
}
}


@@ -0,0 +1,140 @@
package bgreader_test
import (
"bytes"
"errors"
"io"
"math/rand"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn/internal/bgreader"
"github.com/stretchr/testify/require"
)
func TestBGReaderReadWhenStopped(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}
func TestBGReaderReadWhenStarted(t *testing.T) {
r := bytes.NewReader([]byte("foo bar baz"))
bgr := bgreader.New(r)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.NoError(t, err)
require.Equal(t, []byte("foo bar baz"), buf)
}
type mockReadFunc func(p []byte) (int, error)
type mockReader struct {
readFuncs []mockReadFunc
}
func (r *mockReader) Read(p []byte) (int, error) {
if len(r.readFuncs) == 0 {
return 0, io.EOF
}
fn := r.readFuncs[0]
r.readFuncs = r.readFuncs[1:]
return fn(p)
}
func TestBGReaderReadWaitsForBackgroundRead(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { time.Sleep(1 * time.Second); return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), nil },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf := make([]byte, 3)
n, err := bgr.Read(buf)
require.NoError(t, err)
require.EqualValues(t, 3, n)
require.Equal(t, []byte("foo"), buf)
}
func TestBGReaderErrorWhenStarted(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}
bgr := bgreader.New(rr)
bgr.Start()
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}
func TestBGReaderErrorWhenStopped(t *testing.T) {
rr := &mockReader{
readFuncs: []mockReadFunc{
func(p []byte) (int, error) { return copy(p, []byte("foo")), nil },
func(p []byte) (int, error) { return copy(p, []byte("bar")), nil },
func(p []byte) (int, error) { return copy(p, []byte("baz")), errors.New("oops") },
},
}
bgr := bgreader.New(rr)
buf, err := io.ReadAll(bgr)
require.Equal(t, []byte("foobarbaz"), buf)
require.EqualError(t, err, "oops")
}
type numberReader struct {
v uint8
rng *rand.Rand
}
func (nr *numberReader) Read(p []byte) (int, error) {
n := nr.rng.Intn(len(p))
for i := 0; i < n; i++ {
p[i] = nr.v
nr.v++
}
return n, nil
}
// TestBGReaderStress stress tests BGReader by reading a lot of bytes in random sizes while randomly starting and
// stopping the background worker from other goroutines.
func TestBGReaderStress(t *testing.T) {
nr := &numberReader{rng: rand.New(rand.NewSource(0))}
bgr := bgreader.New(nr)
bytesRead := 0
var expected uint8
buf := make([]byte, 10_000)
rng := rand.New(rand.NewSource(0))
for bytesRead < 1_000_000 {
randomNumber := rng.Intn(100)
switch {
case randomNumber < 10:
go bgr.Start()
case randomNumber < 20:
go bgr.Stop()
default:
n, err := bgr.Read(buf)
require.NoError(t, err)
for i := 0; i < n; i++ {
require.Equal(t, expected, buf[i])
expected++
}
bytesRead += n
}
}
}


@@ -63,7 +63,7 @@ func (c *PgConn) gssAuth() error {
		Data: nextData,
	}
	c.frontend.Send(gssResponse)
-	err = c.frontend.Flush()
+	err = c.flushWithPotentialWriteReadDeadlock()
	if err != nil {
		return err
	}

File diff suppressed because it is too large


@@ -14,7 +14,7 @@ import (
)

func TestConnStress(t *testing.T) {
-	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_CONN_STRING"))
+	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

File diff suppressed because it is too large


@@ -1,6 +1,6 @@
# pgproto3

-Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
+Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.

pgproto3 can be used as a foundation for PostgreSQL drivers, proxies, mock servers, load balancers and more.


@@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	dst = pgio.AppendInt32(dst, 8)
+func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.
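All of the rewritten encoders funnel through beginMessage/finishMessage. Those helpers are not shown in this diff; presumably they look roughly like the sketch below, reserving the 4-byte length slot after the type byte and backfilling it, with finishMessage returning an error when the body exceeds the exported MaxMessageBodyLen limit (exercised in bind_test.go further down):

	func beginMessage(dst []byte, t byte) ([]byte, int) {
		dst = append(dst, t)
		sp := len(dst)
		dst = pgio.AppendInt32(dst, -1) // length placeholder, set by finishMessage
		return dst, sp
	}

	func finishMessage(dst []byte, sp int) ([]byte, error) {
		messageBodyLen := len(dst[sp:])
		if messageBodyLen > MaxMessageBodyLen {
			return nil, errors.New("message body too large")
		}
		pgio.SetInt32(dst[sp:], int32(messageBodyLen))
		return dst, nil
	}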


@@ -27,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
	return nil
}

-func (a *AuthenticationGSS) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	dst = pgio.AppendInt32(dst, 4)
+func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeGSS)
-	return dst
+	return finishMessage(dst, sp)
}

func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {


@@ -31,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
	return nil
}

-func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
+func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
	dst = append(dst, a.Data...)
-	return dst
+	return finishMessage(dst, sp)
}

func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {


@@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	dst = pgio.AppendInt32(dst, 12)
+func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
	dst = append(dst, src.Salt[:]...)
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *AuthenticationOk) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	dst = pgio.AppendInt32(dst, 8)
+func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeOk)
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -47,10 +47,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *AuthenticationSASL) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeSASL)

	for _, s := range src.AuthMechanisms {
@@ -59,9 +57,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
	}
	dst = append(dst, 0)

-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
	dst = append(dst, src.Data...)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
-	dst = append(dst, 'R')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'R')
	dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
	dst = append(dst, src.Data...)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Unmarshaler.


@@ -17,6 +17,7 @@ type Backend struct {
	tracer *tracer

	wbuf        []byte
+	encodeError error

	// Frontend message flyweights
	bind Bind
@@ -38,6 +39,7 @@ type Backend struct {
	terminate Terminate

	bodyLen    int
+	maxBodyLen int // maxBodyLen is the maximum length of a message body in octets. If a message body exceeds this length, Receive will return an error.
	msgType    byte
	partialMsg bool
	authType   uint32
@@ -54,11 +56,21 @@ func NewBackend(r io.Reader, w io.Writer) *Backend {
	return &Backend{cr: cr, w: w}
}

-// Send sends a message to the frontend (i.e. the client). The message is not guaranteed to be written until Flush is
-// called.
+// Send sends a message to the frontend (i.e. the client). The message is buffered until Flush is called. Any error
+// encountered will be returned from Flush.
func (b *Backend) Send(msg BackendMessage) {
+	if b.encodeError != nil {
+		return
+	}
+
	prevLen := len(b.wbuf)
-	b.wbuf = msg.Encode(b.wbuf)
+	newBuf, err := msg.Encode(b.wbuf)
+	if err != nil {
+		b.encodeError = err
+		return
+	}
+	b.wbuf = newBuf
	if b.tracer != nil {
		b.tracer.traceMessage('B', int32(len(b.wbuf)-prevLen), msg)
	}
@@ -66,6 +78,12 @@ func (b *Backend) Send(msg BackendMessage) {

// Flush writes any pending messages to the frontend (i.e. the client).
func (b *Backend) Flush() error {
+	if err := b.encodeError; err != nil {
+		b.encodeError = nil
+		b.wbuf = b.wbuf[:0]
+		return &writeError{err: err, safeToRetry: true}
+	}
+
	n, err := b.w.Write(b.wbuf)

	const maxLen = 1024
@@ -157,7 +175,16 @@ func (b *Backend) Receive() (FrontendMessage, error) {
	}

	b.msgType = header[0]
-	b.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4
+
+	msgLength := int(binary.BigEndian.Uint32(header[1:]))
+	if msgLength < 4 {
+		return nil, fmt.Errorf("invalid message length: %d", msgLength)
+	}
+
+	b.bodyLen = msgLength - 4
+	if b.maxBodyLen > 0 && b.bodyLen > b.maxBodyLen {
+		return nil, &ExceededMaxBodyLenErr{b.maxBodyLen, b.bodyLen}
+	}
	b.partialMsg = true
}
@@ -196,7 +223,7 @@ func (b *Backend) Receive() (FrontendMessage, error) {
		case AuthTypeCleartextPassword, AuthTypeMD5Password:
			fallthrough
		default:
-			// to maintain backwards compatability
+			// to maintain backwards compatibility
			msg = &PasswordMessage{}
		}
	case 'Q':
@@ -260,3 +287,13 @@ func (b *Backend) SetAuthType(authType uint32) error {
	return nil
}

+// SetMaxBodyLen sets the maximum length of a message body in octets.
+// If a message body exceeds this length, Receive will return an error.
+// This is useful for protecting against malicious clients that send
+// large messages with the intent of causing memory exhaustion.
+// The default value is 0.
+// If maxBodyLen is 0, then no maximum is enforced.
+func (b *Backend) SetMaxBodyLen(maxBodyLen int) {
+	b.maxBodyLen = maxBodyLen
+}
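A proxy or mock server might use the new guard like this sketch (the 1 MiB cap and connection handling are illustrative):

	backend := pgproto3.NewBackend(clientConn, clientConn)
	backend.SetMaxBodyLen(1 << 20) // reject any message body over 1 MiB

	msg, err := backend.Receive()
	var tooBig *pgproto3.ExceededMaxBodyLenErr
	if errors.As(err, &tooBig) {
		clientConn.Close() // drop the client instead of buffering a huge message
	}
	_ = msg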


@@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *BackendKeyData) Encode(dst []byte) []byte {
-	dst = append(dst, 'K')
-	dst = pgio.AppendUint32(dst, 12)
+func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'K')
	dst = pgio.AppendUint32(dst, src.ProcessID)
	dst = pgio.AppendUint32(dst, src.SecretKey)
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
			"username": "tester",
		},
	}
-	dst := []byte{}
-	dst = want.Encode(dst)
+	dst, err := want.Encode([]byte{})
+	require.NoError(t, err)

	server := &interruptReader{}
	server.push(dst)
@@ -120,3 +120,21 @@ func TestStartupMessage(t *testing.T) {
	})
}

+func TestBackendReceiveExceededMaxBodyLen(t *testing.T) {
+	t.Parallel()
+
+	server := &interruptReader{}
+	server.push([]byte{'Q', 0, 0, 10, 10})
+
+	backend := pgproto3.NewBackend(server, nil)
+
+	// Set max body len to 5
+	backend.SetMaxBodyLen(5)
+
+	// Receive regular msg
+	msg, err := backend.Receive()
+	assert.Nil(t, msg)
+	var invalidBodyLenErr *pgproto3.ExceededMaxBodyLenErr
+	assert.ErrorAs(t, err, &invalidBodyLenErr)
+}


@@ -5,7 +5,9 @@ import (
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
+	"errors"
	"fmt"
+	"math"

	"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *Bind) Encode(dst []byte) []byte {
-	dst = append(dst, 'B')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *Bind) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'B')

	dst = append(dst, src.DestinationPortal...)
	dst = append(dst, 0)
	dst = append(dst, src.PreparedStatement...)
	dst = append(dst, 0)

+	if len(src.ParameterFormatCodes) > math.MaxUint16 {
+		return nil, errors.New("too many parameter format codes")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
	for _, fc := range src.ParameterFormatCodes {
		dst = pgio.AppendInt16(dst, fc)
	}

+	if len(src.Parameters) > math.MaxUint16 {
+		return nil, errors.New("too many parameters")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
	for _, p := range src.Parameters {
		if p == nil {
@@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
		dst = append(dst, p...)
	}

+	if len(src.ResultFormatCodes) > math.MaxUint16 {
+		return nil, errors.New("too many result format codes")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
	for _, fc := range src.ResultFormatCodes {
		dst = pgio.AppendInt16(dst, fc)
	}

-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *BindComplete) Encode(dst []byte) []byte {
-	return append(dst, '2', 0, 0, 0, 4)
+func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
+	return append(dst, '2', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.

pgproto3/bind_test.go (new file, 20 lines)

@@ -0,0 +1,20 @@
package pgproto3_test
import (
"testing"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require"
)
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
t.Parallel()
// Maximum allowed size.
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
require.NoError(t, err)
// 1 byte too big
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
require.Error(t, err)
}


@@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 4 byte message length.
-func (src *CancelRequest) Encode(dst []byte) []byte {
+func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
	dst = pgio.AppendInt32(dst, 16)
	dst = pgio.AppendInt32(dst, cancelRequestCode)
	dst = pgio.AppendUint32(dst, src.ProcessID)
	dst = pgio.AppendUint32(dst, src.SecretKey)
-	return dst
+	return dst, nil
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -14,7 +14,7 @@ import (
type chunkReader struct {
	r io.Reader

-	buf    []byte
+	buf    *[]byte
	rp, wp int // buf read position and write position

	minBufSize int
@@ -45,7 +45,7 @@ func newChunkReader(r io.Reader, minBufSize int) *chunkReader {
func (r *chunkReader) Next(n int) (buf []byte, err error) {
	// Reset the buffer if it is empty
	if r.rp == r.wp {
-		if len(r.buf) != r.minBufSize {
+		if len(*r.buf) != r.minBufSize {
			iobufpool.Put(r.buf)
			r.buf = iobufpool.Get(r.minBufSize)
		}
@@ -55,15 +55,15 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
	// n bytes already in buf
	if (r.wp - r.rp) >= n {
-		buf = r.buf[r.rp : r.rp+n : r.rp+n]
+		buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
		r.rp += n
		return buf, err
	}

	// buf is smaller than requested number of bytes
-	if len(r.buf) < n {
+	if len(*r.buf) < n {
		bigBuf := iobufpool.Get(n)
-		r.wp = copy(bigBuf, r.buf[r.rp:r.wp])
+		r.wp = copy((*bigBuf), (*r.buf)[r.rp:r.wp])
		r.rp = 0
		iobufpool.Put(r.buf)
		r.buf = bigBuf
@@ -71,20 +71,20 @@ func (r *chunkReader) Next(n int) (buf []byte, err error) {
	// buf is large enough, but need to shift filled area to start to make enough contiguous space
	minReadCount := n - (r.wp - r.rp)
-	if (len(r.buf) - r.wp) < minReadCount {
-		r.wp = copy(r.buf, r.buf[r.rp:r.wp])
+	if (len(*r.buf) - r.wp) < minReadCount {
+		r.wp = copy((*r.buf), (*r.buf)[r.rp:r.wp])
		r.rp = 0
	}

	// Read at least the required number of bytes from the underlying io.Reader
-	readBytesCount, err := io.ReadAtLeast(r.r, r.buf[r.wp:], minReadCount)
+	readBytesCount, err := io.ReadAtLeast(r.r, (*r.buf)[r.wp:], minReadCount)
	r.wp += readBytesCount
	// fmt.Println("read", n)
	if err != nil {
		return nil, err
	}

-	buf = r.buf[r.rp : r.rp+n : r.rp+n]
+	buf = (*r.buf)[r.rp : r.rp+n : r.rp+n]
	r.rp += n
	return buf, nil
}
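The []byte-to-*[]byte change tracks iobufpool switching to pooling pointers to slices: placing a plain slice header into a sync.Pool forces a heap allocation on every Put because the header escapes into an interface, while a *[]byte is a single pointer. A sketch of the assumed pattern (not the actual iobufpool source):

	var bufPool = sync.Pool{
		New: func() any {
			buf := make([]byte, 8192)
			return &buf // pointer, so Put does not re-allocate the header
		},
	}

	func getBuf() *[]byte { return bufPool.Get().(*[]byte) }

	func putBuf(buf *[]byte) {
		*buf = (*buf)[:cap(*buf)] // restore full capacity before reuse
		bufPool.Put(buf)
	}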


@@ -17,7 +17,7 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
	if err != nil {
		t.Fatal(err)
	}
-	if bytes.Compare(n1, src[0:2]) != 0 {
+	if !bytes.Equal(n1, src[0:2]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[0:2], n1)
	}
@@ -25,11 +25,11 @@ func TestChunkReaderNextDoesNotReadIfAlreadyBuffered(t *testing.T) {
	if err != nil {
		t.Fatal(err)
	}
-	if bytes.Compare(n2, src[2:4]) != 0 {
+	if !bytes.Equal(n2, src[2:4]) {
		t.Fatalf("Expected read bytes to be %v, but they were %v", src[2:4], n2)
	}

-	if bytes.Compare(r.buf[:len(src)], src) != 0 {
+	if !bytes.Equal((*r.buf)[:len(src)], src) {
		t.Fatalf("Expected r.buf to be %v, but it was %v", src, r.buf)
	}


@@ -4,8 +4,6 @@ import (
	"bytes"
	"encoding/json"
	"errors"
-
-	"github.com/jackc/pgx/v5/internal/pgio"
)

type Close struct {
@@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *Close) Encode(dst []byte) []byte {
-	dst = append(dst, 'C')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *Close) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'C')

	dst = append(dst, src.ObjectType)
	dst = append(dst, src.Name...)
	dst = append(dst, 0)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CloseComplete) Encode(dst []byte) []byte {
-	return append(dst, '3', 0, 0, 0, 4)
+func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
+	return append(dst, '3', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -3,8 +3,6 @@ package pgproto3

import (
	"bytes"
	"encoding/json"
-
-	"github.com/jackc/pgx/v5/internal/pgio"
)

type CommandComplete struct {
@@ -31,17 +29,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CommandComplete) Encode(dst []byte) []byte {
-	dst = append(dst, 'C')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'C')

	dst = append(dst, src.CommandTag...)
	dst = append(dst, 0)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -5,6 +5,7 @@ import (
	"encoding/binary"
	"encoding/json"
	"errors"
+	"math"

	"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -44,19 +45,18 @@ func (dst *CopyBothResponse) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CopyBothResponse) Encode(dst []byte) []byte {
-	dst = append(dst, 'W')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *CopyBothResponse) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'W')
	dst = append(dst, src.OverallFormat)
+	if len(src.ColumnFormatCodes) > math.MaxUint16 {
+		return nil, errors.New("too many column format codes")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
	for _, fc := range src.ColumnFormatCodes {
		dst = pgio.AppendUint16(dst, fc)
	}

-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -5,6 +5,7 @@ import (
	"github.com/jackc/pgx/v5/pgproto3"
	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
)

func TestEncodeDecode(t *testing.T) {
@@ -13,6 +14,7 @@ func TestEncodeDecode(t *testing.T) {
	err := dstResp.Decode(srcBytes[5:])
	assert.NoError(t, err, "No errors on decode")
	dstBytes := []byte{}
-	dstBytes = dstResp.Encode(dstBytes)
+	dstBytes, err = dstResp.Encode(dstBytes)
+	require.NoError(t, err)
	assert.EqualValues(t, srcBytes, dstBytes, "Expecting src & dest bytes to match")
}


@@ -3,8 +3,6 @@ package pgproto3

import (
	"encoding/hex"
	"encoding/json"
-
-	"github.com/jackc/pgx/v5/internal/pgio"
)

type CopyData struct {
@@ -25,11 +23,10 @@ func (dst *CopyData) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CopyData) Encode(dst []byte) []byte {
-	dst = append(dst, 'd')
-	dst = pgio.AppendInt32(dst, int32(4+len(src.Data)))
+func (src *CopyData) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'd')
	dst = append(dst, src.Data...)
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -24,8 +24,8 @@ func (dst *CopyDone) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CopyDone) Encode(dst []byte) []byte {
-	return append(dst, 'c', 0, 0, 0, 4)
+func (src *CopyDone) Encode(dst []byte) ([]byte, error) {
+	return append(dst, 'c', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -3,8 +3,6 @@ package pgproto3

import (
	"bytes"
	"encoding/json"
-
-	"github.com/jackc/pgx/v5/internal/pgio"
)

type CopyFail struct {
@@ -28,17 +26,11 @@ func (dst *CopyFail) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CopyFail) Encode(dst []byte) []byte {
-	dst = append(dst, 'f')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *CopyFail) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'f')

	dst = append(dst, src.Message...)
	dst = append(dst, 0)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -5,6 +5,7 @@ import (
	"encoding/binary"
	"encoding/json"
	"errors"
+	"math"

	"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -44,20 +45,19 @@ func (dst *CopyInResponse) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CopyInResponse) Encode(dst []byte) []byte {
-	dst = append(dst, 'G')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *CopyInResponse) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'G')
	dst = append(dst, src.OverallFormat)

+	if len(src.ColumnFormatCodes) > math.MaxUint16 {
+		return nil, errors.New("too many column format codes")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
	for _, fc := range src.ColumnFormatCodes {
		dst = pgio.AppendUint16(dst, fc)
	}

-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -5,6 +5,7 @@ import (
	"encoding/binary"
	"encoding/json"
	"errors"
+	"math"

	"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -43,21 +44,20 @@ func (dst *CopyOutResponse) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *CopyOutResponse) Encode(dst []byte) []byte {
-	dst = append(dst, 'H')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *CopyOutResponse) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'H')
	dst = append(dst, src.OverallFormat)

+	if len(src.ColumnFormatCodes) > math.MaxUint16 {
+		return nil, errors.New("too many column format codes")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.ColumnFormatCodes)))
	for _, fc := range src.ColumnFormatCodes {
		dst = pgio.AppendUint16(dst, fc)
	}

-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -4,6 +4,8 @@ import (
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
+	"errors"
+	"math"

	"github.com/jackc/pgx/v5/internal/pgio"
)
@@ -63,11 +65,12 @@ func (dst *DataRow) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *DataRow) Encode(dst []byte) []byte {
-	dst = append(dst, 'D')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *DataRow) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'D')

+	if len(src.Values) > math.MaxUint16 {
+		return nil, errors.New("too many values")
+	}
	dst = pgio.AppendUint16(dst, uint16(len(src.Values)))
	for _, v := range src.Values {
		if v == nil {
@@ -79,9 +82,7 @@ func (src *DataRow) Encode(dst []byte) []byte {
		dst = append(dst, v...)
	}

-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -4,8 +4,6 @@ import (
	"bytes"
	"encoding/json"
	"errors"
-
-	"github.com/jackc/pgx/v5/internal/pgio"
)

type Describe struct {
@@ -37,18 +35,12 @@ func (dst *Describe) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *Describe) Encode(dst []byte) []byte {
-	dst = append(dst, 'D')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *Describe) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'D')

	dst = append(dst, src.ObjectType)
	dst = append(dst, src.Name...)
	dst = append(dst, 0)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-	return dst
+	return finishMessage(dst, sp)
}

// MarshalJSON implements encoding/json.Marshaler.


@@ -1,7 +1,7 @@
-// Package pgproto3 is a encoder and decoder of the PostgreSQL wire protocol version 3.
+// Package pgproto3 is an encoder and decoder of the PostgreSQL wire protocol version 3.
//
// The primary interfaces are Frontend and Backend. They correspond to a client and server respectively. Messages are
-// sent with Send (or a specialized Send variant). Messages are automatically bufferred to minimize small writes. Call
+// sent with Send (or a specialized Send variant). Messages are automatically buffered to minimize small writes. Call
// Flush to ensure a message has actually been sent.
//
// The Trace method of Frontend and Backend can be used to examine the wire-level message traffic. It outputs in a
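A minimal sketch of driving pgproto3 directly, assuming conn is a net.Conn that has already completed startup and authentication:

	frontend := pgproto3.NewFrontend(conn, conn)
	frontend.Trace(os.Stdout, pgproto3.TracerOptions{}) // dump wire traffic

	frontend.Send(&pgproto3.Query{String: "select 1"})
	if err := frontend.Flush(); err != nil {
		log.Fatal(err)
	}
	for {
		msg, err := frontend.Receive()
		if err != nil {
			log.Fatal(err)
		}
		if _, ok := msg.(*pgproto3.ReadyForQuery); ok {
			break // query cycle complete
		}
	}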


@@ -20,8 +20,8 @@ func (dst *EmptyQueryResponse) Decode(src []byte) error {
}

// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *EmptyQueryResponse) Encode(dst []byte) []byte {
-	return append(dst, 'I', 0, 0, 0, 4)
+func (src *EmptyQueryResponse) Encode(dst []byte) ([]byte, error) {
+	return append(dst, 'I', 0, 0, 0, 4), nil
}

// MarshalJSON implements encoding/json.Marshaler.


@ -2,7 +2,6 @@ package pgproto3
import ( import (
"bytes" "bytes"
"encoding/binary"
"encoding/json" "encoding/json"
"strconv" "strconv"
) )
@@ -111,120 +110,113 @@ func (dst *ErrorResponse) Decode(src []byte) error {
 }
 
 // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *ErrorResponse) Encode(dst []byte) []byte {
-	return append(dst, src.marshalBinary('E')...)
+func (src *ErrorResponse) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'E')
+	dst = src.appendFields(dst)
+	return finishMessage(dst, sp)
 }
 
-func (src *ErrorResponse) marshalBinary(typeByte byte) []byte {
-	var bigEndian BigEndianBuf
-	buf := &bytes.Buffer{}
-	buf.WriteByte(typeByte)
-	buf.Write(bigEndian.Uint32(0))
+func (src *ErrorResponse) appendFields(dst []byte) []byte {
 	if src.Severity != "" {
-		buf.WriteByte('S')
-		buf.WriteString(src.Severity)
-		buf.WriteByte(0)
+		dst = append(dst, 'S')
+		dst = append(dst, src.Severity...)
+		dst = append(dst, 0)
 	}
 	if src.SeverityUnlocalized != "" {
-		buf.WriteByte('V')
-		buf.WriteString(src.SeverityUnlocalized)
-		buf.WriteByte(0)
+		dst = append(dst, 'V')
+		dst = append(dst, src.SeverityUnlocalized...)
+		dst = append(dst, 0)
 	}
 	if src.Code != "" {
-		buf.WriteByte('C')
-		buf.WriteString(src.Code)
-		buf.WriteByte(0)
+		dst = append(dst, 'C')
+		dst = append(dst, src.Code...)
+		dst = append(dst, 0)
 	}
 	if src.Message != "" {
-		buf.WriteByte('M')
-		buf.WriteString(src.Message)
-		buf.WriteByte(0)
+		dst = append(dst, 'M')
+		dst = append(dst, src.Message...)
+		dst = append(dst, 0)
 	}
 	if src.Detail != "" {
-		buf.WriteByte('D')
-		buf.WriteString(src.Detail)
-		buf.WriteByte(0)
+		dst = append(dst, 'D')
+		dst = append(dst, src.Detail...)
+		dst = append(dst, 0)
 	}
 	if src.Hint != "" {
-		buf.WriteByte('H')
-		buf.WriteString(src.Hint)
-		buf.WriteByte(0)
+		dst = append(dst, 'H')
+		dst = append(dst, src.Hint...)
+		dst = append(dst, 0)
 	}
 	if src.Position != 0 {
-		buf.WriteByte('P')
-		buf.WriteString(strconv.Itoa(int(src.Position)))
-		buf.WriteByte(0)
+		dst = append(dst, 'P')
+		dst = append(dst, strconv.Itoa(int(src.Position))...)
+		dst = append(dst, 0)
 	}
 	if src.InternalPosition != 0 {
-		buf.WriteByte('p')
-		buf.WriteString(strconv.Itoa(int(src.InternalPosition)))
-		buf.WriteByte(0)
+		dst = append(dst, 'p')
+		dst = append(dst, strconv.Itoa(int(src.InternalPosition))...)
+		dst = append(dst, 0)
 	}
 	if src.InternalQuery != "" {
-		buf.WriteByte('q')
-		buf.WriteString(src.InternalQuery)
-		buf.WriteByte(0)
+		dst = append(dst, 'q')
+		dst = append(dst, src.InternalQuery...)
+		dst = append(dst, 0)
 	}
 	if src.Where != "" {
-		buf.WriteByte('W')
-		buf.WriteString(src.Where)
-		buf.WriteByte(0)
+		dst = append(dst, 'W')
+		dst = append(dst, src.Where...)
+		dst = append(dst, 0)
 	}
 	if src.SchemaName != "" {
-		buf.WriteByte('s')
-		buf.WriteString(src.SchemaName)
-		buf.WriteByte(0)
+		dst = append(dst, 's')
+		dst = append(dst, src.SchemaName...)
+		dst = append(dst, 0)
 	}
 	if src.TableName != "" {
-		buf.WriteByte('t')
-		buf.WriteString(src.TableName)
-		buf.WriteByte(0)
+		dst = append(dst, 't')
+		dst = append(dst, src.TableName...)
+		dst = append(dst, 0)
 	}
 	if src.ColumnName != "" {
-		buf.WriteByte('c')
-		buf.WriteString(src.ColumnName)
-		buf.WriteByte(0)
+		dst = append(dst, 'c')
+		dst = append(dst, src.ColumnName...)
+		dst = append(dst, 0)
 	}
 	if src.DataTypeName != "" {
-		buf.WriteByte('d')
-		buf.WriteString(src.DataTypeName)
-		buf.WriteByte(0)
+		dst = append(dst, 'd')
+		dst = append(dst, src.DataTypeName...)
+		dst = append(dst, 0)
 	}
 	if src.ConstraintName != "" {
-		buf.WriteByte('n')
-		buf.WriteString(src.ConstraintName)
-		buf.WriteByte(0)
+		dst = append(dst, 'n')
+		dst = append(dst, src.ConstraintName...)
+		dst = append(dst, 0)
 	}
 	if src.File != "" {
-		buf.WriteByte('F')
-		buf.WriteString(src.File)
-		buf.WriteByte(0)
+		dst = append(dst, 'F')
+		dst = append(dst, src.File...)
+		dst = append(dst, 0)
 	}
 	if src.Line != 0 {
-		buf.WriteByte('L')
-		buf.WriteString(strconv.Itoa(int(src.Line)))
-		buf.WriteByte(0)
+		dst = append(dst, 'L')
+		dst = append(dst, strconv.Itoa(int(src.Line))...)
+		dst = append(dst, 0)
 	}
 	if src.Routine != "" {
-		buf.WriteByte('R')
-		buf.WriteString(src.Routine)
-		buf.WriteByte(0)
+		dst = append(dst, 'R')
+		dst = append(dst, src.Routine...)
+		dst = append(dst, 0)
 	}
 	for k, v := range src.UnknownFields {
-		buf.WriteByte(k)
-		buf.WriteByte(0)
-		buf.WriteString(v)
-		buf.WriteByte(0)
+		dst = append(dst, k)
+		dst = append(dst, v...)
+		dst = append(dst, 0)
 	}
-	buf.WriteByte(0)
-	binary.BigEndian.PutUint32(buf.Bytes()[1:5], uint32(buf.Len()-1))
-	return buf.Bytes()
+	dst = append(dst, 0)
+	return dst
 }
 
 // MarshalJSON implements encoding/json.Marshaler.
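The beginMessage and finishMessage helpers used above are not shown in this diff. Judging from the manual framing they replace (visible in the Execute hunk below), a plausible sketch of the pattern is: reserve a 4-byte length placeholder after the type byte, then backfill it once the body is complete, returning an error instead of silently truncating an oversized message:

// Hypothetical reconstruction, not the verbatim pgproto3 source.
func beginMessage(dst []byte, t byte) ([]byte, int) {
	dst = append(dst, t)
	sp := len(dst)
	dst = pgio.AppendInt32(dst, -1) // length placeholder, backfilled by finishMessage
	return dst, sp
}

func finishMessage(dst []byte, sp int) ([]byte, error) {
	messageBodySize := len(dst[sp:])
	if messageBodySize > math.MaxInt32 {
		return nil, errors.New("message body too large")
	}
	pgio.SetInt32(dst[sp:], int32(messageBodySize))
	return dst, nil
}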

View File

@@ -46,7 +46,7 @@ func (p *PgFortuneBackend) Run() error {
 				return fmt.Errorf("error generating query response: %w", err)
 			}
 
-			buf := (&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
+			buf := mustEncode((&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
 				{
 					Name: []byte("fortune"),
 					TableOID: 0,
@@ -56,10 +56,10 @@ func (p *PgFortuneBackend) Run() error {
 					TypeModifier: -1,
 					Format: 0,
 				},
-			}}).Encode(nil)
-			buf = (&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf)
-			buf = (&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf)
-			buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
+			}}).Encode(nil))
+			buf = mustEncode((&pgproto3.DataRow{Values: [][]byte{response}}).Encode(buf))
+			buf = mustEncode((&pgproto3.CommandComplete{CommandTag: []byte("SELECT 1")}).Encode(buf))
+			buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
 			_, err = p.conn.Write(buf)
 			if err != nil {
 				return fmt.Errorf("error writing query response: %w", err)
@@ -80,8 +80,8 @@ func (p *PgFortuneBackend) handleStartup() error {
 	switch startupMessage.(type) {
 	case *pgproto3.StartupMessage:
-		buf := (&pgproto3.AuthenticationOk{}).Encode(nil)
-		buf = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf)
+		buf := mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil))
+		buf = mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(buf))
 		_, err = p.conn.Write(buf)
 		if err != nil {
 			return fmt.Errorf("error sending ready for query: %w", err)
@@ -102,3 +102,10 @@ func (p *PgFortuneBackend) handleStartup() error {
 func (p *PgFortuneBackend) Close() error {
 	return p.conn.Close()
 }
+
+func mustEncode(buf []byte, err error) []byte {
+	if err != nil {
+		panic(err)
+	}
+	return buf
+}

View File

@@ -36,19 +36,12 @@ func (dst *Execute) Decode(src []byte) error {
 }
 
 // Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
-func (src *Execute) Encode(dst []byte) []byte {
-	dst = append(dst, 'E')
-	sp := len(dst)
-	dst = pgio.AppendInt32(dst, -1)
+func (src *Execute) Encode(dst []byte) ([]byte, error) {
+	dst, sp := beginMessage(dst, 'E')
 	dst = append(dst, src.Portal...)
 	dst = append(dst, 0)
 	dst = pgio.AppendUint32(dst, src.MaxRows)
-
-	pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
-
-	return dst
+	return finishMessage(dst, sp)
 }
 
 // MarshalJSON implements encoding/json.Marshaler.

Some files were not shown because too many files have changed in this diff.