mirror of https://github.com/jackc/pgx.git
Compare commits
84 Commits
Author | SHA1 | Date |
---|---|---|
|
0f77a2d028 | |
|
ddd966f09f | |
|
924834b5b4 | |
|
9b15554c51 | |
|
037e4cf9a2 | |
|
04bcc0219d | |
|
0e0a7d8344 | |
|
63422c7d6c | |
|
5c1fbf4806 | |
|
05fe5f8b05 | |
|
70c9a147a2 | |
|
6603ddfbe4 | |
|
70f7cad222 | |
|
6bf1b0b1b9 | |
|
14bda65a0c | |
|
9e3c4fb40f | |
|
05e72a5ab1 | |
|
47d631e34b | |
|
58b05f567c | |
|
dcb7193669 | |
|
1abf7d9050 | |
|
b5efc90a32 | |
|
a26c93551f | |
|
2100e1da46 | |
|
2d21a2b80d | |
|
5f33ee5f07 | |
|
228cfffc20 | |
|
a5353af354 | |
|
0bc29e3000 | |
|
9cce05944a | |
|
9c0ad690a9 | |
|
03f08abda3 | |
|
2c1b1c389a | |
|
329cb45913 | |
|
c96a55f8c0 | |
|
e87760682f | |
|
f681632c68 | |
|
3c640a44b6 | |
|
de3f868c1d | |
|
5424d3c873 | |
|
42d3d00734 | |
|
cdc672cf3f | |
|
52e2858629 | |
|
e352784fed | |
|
c2175fe46e | |
|
659823f8f3 | |
|
ca04098fab | |
|
4ff0a454e0 | |
|
00b86ca3db | |
|
61a0227241 | |
|
2190a8e0d1 | |
|
6e9fa42fef | |
|
6d9e6a726e | |
|
02e387ea64 | |
|
e452f80b1d | |
|
da0315d1a4 | |
|
120c89fe0d | |
|
057937db27 | |
|
47cbd8edb8 | |
|
90a77b13b2 | |
|
59d6aa87b9 | |
|
39ffc8b7a4 | |
|
c4c1076d28 | |
|
4293b25262 | |
|
ea1e13a660 | |
|
58d4c0c94f | |
|
1752f7b4c1 | |
|
ee718a110d | |
|
546ad2f4e2 | |
|
efc2c9ff44 | |
|
aabed18db8 | |
|
afa974fb05 | |
|
12b37f3218 | |
|
bcf3fbd780 | |
|
f7c3d190ad | |
|
473a241b96 | |
|
311f72afdc | |
|
877111ceeb | |
|
dc3aea06b5 | |
|
e5d321f920 | |
|
17cd36818c | |
|
76593f37f7 | |
|
29751194ef | |
|
c1f4cbb5cd |
|
@ -14,18 +14,8 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
go-version: ["1.22", "1.23"]
|
go-version: ["1.22", "1.23"]
|
||||||
pg-version: [12, 13, 14, 15, 16, cockroachdb]
|
pg-version: [13, 14, 15, 16, 17, cockroachdb]
|
||||||
include:
|
include:
|
||||||
- pg-version: 12
|
|
||||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
|
||||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
|
||||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
|
||||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
|
||||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
|
||||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
|
||||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
|
||||||
pgx-ssl-password: certpw
|
|
||||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
|
||||||
- pg-version: 13
|
- pg-version: 13
|
||||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
@ -66,6 +56,16 @@ jobs:
|
||||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||||
pgx-ssl-password: certpw
|
pgx-ssl-password: certpw
|
||||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||||
|
- pg-version: 17
|
||||||
|
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||||
|
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||||
|
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||||
|
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||||
|
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||||
|
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||||
|
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||||
|
pgx-ssl-password: certpw
|
||||||
|
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||||
- pg-version: cockroachdb
|
- pg-version: cockroachdb
|
||||||
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||||
|
|
||||||
|
|
18
CHANGELOG.md
18
CHANGELOG.md
|
@ -1,3 +1,20 @@
|
||||||
|
# 5.7.4 (March 24, 2025)
|
||||||
|
|
||||||
|
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
|
||||||
|
|
||||||
|
# 5.7.3 (March 21, 2025)
|
||||||
|
|
||||||
|
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
|
||||||
|
* Improve SQL sanitizer performance (ninedraft)
|
||||||
|
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
|
||||||
|
* Fix Values() for xml type always returning nil instead of []byte
|
||||||
|
* Add ability to send Flush message in pipeline mode (zenkovev)
|
||||||
|
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
|
||||||
|
* Better error messages when scanning structs (logicbomb)
|
||||||
|
* Fix handling of error on batch write (bonnefoa)
|
||||||
|
* Match libpq's connection fallback behavior more closely (felix-roehrich)
|
||||||
|
* Add MinIdleConns to pgxpool (djahandarie)
|
||||||
|
|
||||||
# 5.7.2 (December 21, 2024)
|
# 5.7.2 (December 21, 2024)
|
||||||
|
|
||||||
* Fix prepared statement already exists on batch prepare failure
|
* Fix prepared statement already exists on batch prepare failure
|
||||||
|
@ -9,6 +26,7 @@
|
||||||
* Implement pgtype.UUID.String() (Konstantin Grachev)
|
* Implement pgtype.UUID.String() (Konstantin Grachev)
|
||||||
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
|
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
|
||||||
* Update golang.org/x/crypto
|
* Update golang.org/x/crypto
|
||||||
|
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
|
||||||
|
|
||||||
# 5.7.1 (September 10, 2024)
|
# 5.7.1 (September 10, 2024)
|
||||||
|
|
||||||
|
|
14
README.md
14
README.md
|
@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
|
||||||
|
|
||||||
## Supported Go and PostgreSQL Versions
|
## Supported Go and PostgreSQL Versions
|
||||||
|
|
||||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.22 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||||
|
|
||||||
## Version Policy
|
## Version Policy
|
||||||
|
|
||||||
|
@ -172,3 +172,15 @@ Supports, structs, maps, slices and custom mapping functions.
|
||||||
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
|
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
|
||||||
|
|
||||||
Code first migration library for native pgx (no database/sql abstraction).
|
Code first migration library for native pgx (no database/sql abstraction).
|
||||||
|
|
||||||
|
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
|
||||||
|
|
||||||
|
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
|
||||||
|
|
||||||
|
### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
|
||||||
|
|
||||||
|
Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver.
|
||||||
|
|
||||||
|
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
|
||||||
|
|
||||||
|
Simplifies working with the pgx library, providing convenient scanning of nested structures.
|
||||||
|
|
2
Rakefile
2
Rakefile
|
@ -2,7 +2,7 @@ require "erb"
|
||||||
|
|
||||||
rule '.go' => '.go.erb' do |task|
|
rule '.go' => '.go.erb' do |task|
|
||||||
erb = ERB.new(File.read(task.source))
|
erb = ERB.new(File.read(task.source))
|
||||||
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
|
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
|
||||||
sh "goimports", "-w", task.name
|
sh "goimports", "-w", task.name
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -42,8 +42,8 @@ fi
|
||||||
|
|
||||||
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||||
then
|
then
|
||||||
wget -qO- https://binaries.cockroachdb.com/cockroach-v23.1.3.linux-amd64.tgz | tar xvz
|
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
|
||||||
sudo mv cockroach-v23.1.3.linux-amd64/cockroach /usr/local/bin/
|
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
|
||||||
cockroach start-single-node --insecure --background --listen-addr=localhost
|
cockroach start-single-node --insecure --background --listen-addr=localhost
|
||||||
cockroach sql --insecure -e 'create database pgx_test'
|
cockroach sql --insecure -e 'create database pgx_test'
|
||||||
fi
|
fi
|
||||||
|
|
21
conn.go
21
conn.go
|
@ -420,7 +420,7 @@ func (c *Conn) IsClosed() bool {
|
||||||
return c.pgConn.IsClosed()
|
return c.pgConn.IsClosed()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) die(err error) {
|
func (c *Conn) die() {
|
||||||
if c.IsClosed() {
|
if c.IsClosed() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -588,14 +588,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
|
||||||
return result.CommandTag, result.Err
|
return result.CommandTag, result.Err
|
||||||
}
|
}
|
||||||
|
|
||||||
type unknownArgumentTypeQueryExecModeExecError struct {
|
|
||||||
arg any
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
|
|
||||||
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
|
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
|
||||||
err := c.eqb.Build(c.typeMap, nil, args)
|
err := c.eqb.Build(c.typeMap, nil, args)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -661,11 +653,12 @@ const (
|
||||||
// should implement pgtype.Int64Valuer.
|
// should implement pgtype.Int64Valuer.
|
||||||
QueryExecModeExec
|
QueryExecModeExec
|
||||||
|
|
||||||
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. Queries
|
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is
|
||||||
// are executed in a single round trip. Type mappings can be registered with pgtype.Map.RegisterDefaultPgType. Queries
|
// especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used
|
||||||
// will be rejected that have arguments that are unregistered or ambiguous. e.g. A map[string]string may have the
|
// instead for text type values including json and jsonb. Type mappings can be registered with
|
||||||
// PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a map[string]string directly as an
|
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
|
||||||
// argument. This mode cannot.
|
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a
|
||||||
|
// map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip.
|
||||||
//
|
//
|
||||||
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes
|
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes
|
||||||
// the warning regarding differences in text format and binary format encoding with user defined types. There may be
|
// the warning regarding differences in text format and binary format encoding with user defined types. There may be
|
||||||
|
|
|
@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
|
||||||
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
||||||
|
|
||||||
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
||||||
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,7 +161,7 @@ type derivedTypeInfo struct {
|
||||||
// The result of this call can be passed into RegisterTypes to complete the process.
|
// The result of this call can be passed into RegisterTypes to complete the process.
|
||||||
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
|
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
|
||||||
m := c.TypeMap()
|
m := c.TypeMap()
|
||||||
if typeNames == nil || len(typeNames) == 0 {
|
if len(typeNames) == 0 {
|
||||||
return nil, fmt.Errorf("No type names were supplied.")
|
return nil, fmt.Errorf("No type names were supplied.")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -169,13 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
|
||||||
// the SQL not support recent structures such as multirange
|
// the SQL not support recent structures such as multirange
|
||||||
serverVersion, _ := serverVersion(c)
|
serverVersion, _ := serverVersion(c)
|
||||||
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
|
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
|
||||||
var rows Rows
|
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
||||||
var err error
|
|
||||||
if typeNames == nil {
|
|
||||||
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
|
|
||||||
} else {
|
|
||||||
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("While generating load types query: %w", err)
|
return nil, fmt.Errorf("While generating load types query: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -232,15 +226,15 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
|
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
|
||||||
}
|
}
|
||||||
if type_ != nil {
|
|
||||||
m.RegisterType(type_)
|
// the type_ is imposible to be null
|
||||||
if ti.NspName != "" {
|
m.RegisterType(type_)
|
||||||
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
|
if ti.NspName != "" {
|
||||||
m.RegisterType(nspType)
|
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
|
||||||
result = append(result, nspType)
|
m.RegisterType(nspType)
|
||||||
}
|
result = append(result, nspType)
|
||||||
result = append(result, type_)
|
|
||||||
}
|
}
|
||||||
|
result = append(result, type_)
|
||||||
}
|
}
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||||
|
if [ "$current_branch" == "HEAD" ]; then
|
||||||
|
current_branch=$(git rev-parse HEAD)
|
||||||
|
fi
|
||||||
|
|
||||||
|
restore_branch() {
|
||||||
|
echo "Restoring original branch/commit: $current_branch"
|
||||||
|
git checkout "$current_branch"
|
||||||
|
}
|
||||||
|
trap restore_branch EXIT
|
||||||
|
|
||||||
|
# Check if there are uncommitted changes
|
||||||
|
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||||
|
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Ensure that at least one commit argument is passed
|
||||||
|
if [ "$#" -lt 1 ]; then
|
||||||
|
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
commits=("$@")
|
||||||
|
benchmarks_dir=benchmarks
|
||||||
|
|
||||||
|
if ! mkdir -p "${benchmarks_dir}"; then
|
||||||
|
echo "Unable to create dir for benchmarks data"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Benchmark results
|
||||||
|
bench_files=()
|
||||||
|
|
||||||
|
# Run benchmark for each listed commit
|
||||||
|
for i in "${!commits[@]}"; do
|
||||||
|
commit="${commits[i]}"
|
||||||
|
git checkout "$commit" || {
|
||||||
|
echo "Failed to checkout $commit"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Sanitized commmit message
|
||||||
|
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||||
|
|
||||||
|
# Benchmark data will go there
|
||||||
|
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||||
|
|
||||||
|
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||||
|
echo "Benchmarking failed for commit $commit"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
bench_files+=("$bench_file")
|
||||||
|
done
|
||||||
|
|
||||||
|
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||||
|
benchstat "${bench_files[@]}"
|
|
@ -4,8 +4,10 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
@ -24,18 +26,33 @@ type Query struct {
|
||||||
// https://github.com/jackc/pgx/issues/1380
|
// https://github.com/jackc/pgx/issues/1380
|
||||||
const replacementcharacterwidth = 3
|
const replacementcharacterwidth = 3
|
||||||
|
|
||||||
|
const maxBufSize = 16384 // 16 Ki
|
||||||
|
|
||||||
|
var bufPool = &pool[*bytes.Buffer]{
|
||||||
|
new: func() *bytes.Buffer {
|
||||||
|
return &bytes.Buffer{}
|
||||||
|
},
|
||||||
|
reset: func(b *bytes.Buffer) bool {
|
||||||
|
n := b.Len()
|
||||||
|
b.Reset()
|
||||||
|
return n < maxBufSize
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var null = []byte("null")
|
||||||
|
|
||||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
argUse := make([]bool, len(args))
|
argUse := make([]bool, len(args))
|
||||||
buf := &bytes.Buffer{}
|
buf := bufPool.get()
|
||||||
|
defer bufPool.put(buf)
|
||||||
|
|
||||||
for _, part := range q.Parts {
|
for _, part := range q.Parts {
|
||||||
var str string
|
|
||||||
switch part := part.(type) {
|
switch part := part.(type) {
|
||||||
case string:
|
case string:
|
||||||
str = part
|
buf.WriteString(part)
|
||||||
case int:
|
case int:
|
||||||
argIdx := part - 1
|
argIdx := part - 1
|
||||||
|
var p []byte
|
||||||
if argIdx < 0 {
|
if argIdx < 0 {
|
||||||
return "", fmt.Errorf("first sql argument must be > 0")
|
return "", fmt.Errorf("first sql argument must be > 0")
|
||||||
}
|
}
|
||||||
|
@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
if argIdx >= len(args) {
|
if argIdx >= len(args) {
|
||||||
return "", fmt.Errorf("insufficient arguments")
|
return "", fmt.Errorf("insufficient arguments")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Prevent SQL injection via Line Comment Creation
|
||||||
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||||
|
buf.WriteByte(' ')
|
||||||
|
|
||||||
arg := args[argIdx]
|
arg := args[argIdx]
|
||||||
switch arg := arg.(type) {
|
switch arg := arg.(type) {
|
||||||
case nil:
|
case nil:
|
||||||
str = "null"
|
p = null
|
||||||
case int64:
|
case int64:
|
||||||
str = strconv.FormatInt(arg, 10)
|
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
||||||
case float64:
|
case float64:
|
||||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
||||||
case bool:
|
case bool:
|
||||||
str = strconv.FormatBool(arg)
|
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
||||||
case []byte:
|
case []byte:
|
||||||
str = QuoteBytes(arg)
|
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
||||||
case string:
|
case string:
|
||||||
str = QuoteString(arg)
|
p = QuoteString(buf.AvailableBuffer(), arg)
|
||||||
case time.Time:
|
case time.Time:
|
||||||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
p = arg.Truncate(time.Microsecond).
|
||||||
|
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||||
}
|
}
|
||||||
argUse[argIdx] = true
|
argUse[argIdx] = true
|
||||||
|
|
||||||
|
buf.Write(p)
|
||||||
|
|
||||||
// Prevent SQL injection via Line Comment Creation
|
// Prevent SQL injection via Line Comment Creation
|
||||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||||
str = " " + str + " "
|
buf.WriteByte(' ')
|
||||||
default:
|
default:
|
||||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||||
}
|
}
|
||||||
buf.WriteString(str)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, used := range argUse {
|
for i, used := range argUse {
|
||||||
|
@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewQuery(sql string) (*Query, error) {
|
func NewQuery(sql string) (*Query, error) {
|
||||||
l := &sqlLexer{
|
query := &Query{}
|
||||||
src: sql,
|
query.init(sql)
|
||||||
stateFn: rawState,
|
|
||||||
|
return query, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sqlLexerPool = &pool[*sqlLexer]{
|
||||||
|
new: func() *sqlLexer {
|
||||||
|
return &sqlLexer{}
|
||||||
|
},
|
||||||
|
reset: func(sl *sqlLexer) bool {
|
||||||
|
*sl = sqlLexer{}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) init(sql string) {
|
||||||
|
parts := q.Parts[:0]
|
||||||
|
if parts == nil {
|
||||||
|
// dirty, but fast heuristic to preallocate for ~90% usecases
|
||||||
|
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
||||||
|
parts = make([]Part, 0, n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l := sqlLexerPool.get()
|
||||||
|
defer sqlLexerPool.put(l)
|
||||||
|
|
||||||
|
l.src = sql
|
||||||
|
l.stateFn = rawState
|
||||||
|
l.parts = parts
|
||||||
|
|
||||||
for l.stateFn != nil {
|
for l.stateFn != nil {
|
||||||
l.stateFn = l.stateFn(l)
|
l.stateFn = l.stateFn(l)
|
||||||
}
|
}
|
||||||
|
|
||||||
query := &Query{Parts: l.parts}
|
q.Parts = l.parts
|
||||||
|
|
||||||
return query, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuoteString(str string) string {
|
func QuoteString(dst []byte, str string) []byte {
|
||||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
const quote = '\''
|
||||||
|
|
||||||
|
// Preallocate space for the worst case scenario
|
||||||
|
dst = slices.Grow(dst, len(str)*2+2)
|
||||||
|
|
||||||
|
// Add opening quote
|
||||||
|
dst = append(dst, quote)
|
||||||
|
|
||||||
|
// Iterate through the string without allocating
|
||||||
|
for i := 0; i < len(str); i++ {
|
||||||
|
if str[i] == quote {
|
||||||
|
dst = append(dst, quote, quote)
|
||||||
|
} else {
|
||||||
|
dst = append(dst, str[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add closing quote
|
||||||
|
dst = append(dst, quote)
|
||||||
|
|
||||||
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func QuoteBytes(buf []byte) string {
|
func QuoteBytes(dst, buf []byte) []byte {
|
||||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
if len(buf) == 0 {
|
||||||
|
return append(dst, `'\x'`...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate required length
|
||||||
|
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
|
||||||
|
|
||||||
|
// Ensure dst has enough capacity
|
||||||
|
if cap(dst)-len(dst) < requiredLen {
|
||||||
|
newDst := make([]byte, len(dst), len(dst)+requiredLen)
|
||||||
|
copy(newDst, dst)
|
||||||
|
dst = newDst
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record original length and extend slice
|
||||||
|
origLen := len(dst)
|
||||||
|
dst = dst[:origLen+requiredLen]
|
||||||
|
|
||||||
|
// Add prefix
|
||||||
|
dst[origLen] = '\''
|
||||||
|
dst[origLen+1] = '\\'
|
||||||
|
dst[origLen+2] = 'x'
|
||||||
|
|
||||||
|
// Encode bytes directly into dst
|
||||||
|
hex.Encode(dst[origLen+3:len(dst)-1], buf)
|
||||||
|
|
||||||
|
// Add suffix
|
||||||
|
dst[len(dst)-1] = '\''
|
||||||
|
|
||||||
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
type sqlLexer struct {
|
type sqlLexer struct {
|
||||||
|
@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var queryPool = &pool[*Query]{
|
||||||
|
new: func() *Query {
|
||||||
|
return &Query{}
|
||||||
|
},
|
||||||
|
reset: func(q *Query) bool {
|
||||||
|
n := len(q.Parts)
|
||||||
|
q.Parts = q.Parts[:0]
|
||||||
|
return n < 64 // drop too large queries
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||||
// as necessary. This function is only safe when standard_conforming_strings is
|
// as necessary. This function is only safe when standard_conforming_strings is
|
||||||
// on.
|
// on.
|
||||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||||
query, err := NewQuery(sql)
|
query := queryPool.get()
|
||||||
if err != nil {
|
query.init(sql)
|
||||||
return "", err
|
defer queryPool.put(query)
|
||||||
}
|
|
||||||
return query.Sanitize(args...)
|
return query.Sanitize(args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type pool[E any] struct {
|
||||||
|
p sync.Pool
|
||||||
|
new func() E
|
||||||
|
reset func(E) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool *pool[E]) get() E {
|
||||||
|
v, ok := pool.p.Get().(E)
|
||||||
|
if !ok {
|
||||||
|
v = pool.new()
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *pool[E]) put(v E) {
|
||||||
|
if p.reset(v) {
|
||||||
|
p.p.Put(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,62 @@
|
||||||
|
// sanitize_benchmark_test.go
|
||||||
|
package sanitize_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||||
|
)
|
||||||
|
|
||||||
|
var benchmarkSanitizeResult string
|
||||||
|
|
||||||
|
const benchmarkQuery = "" +
|
||||||
|
`SELECT *
|
||||||
|
FROM "water_containers"
|
||||||
|
WHERE NOT "id" = $1 -- int64
|
||||||
|
AND "tags" NOT IN $2 -- nil
|
||||||
|
AND "volume" > $3 -- float64
|
||||||
|
AND "transportable" = $4 -- bool
|
||||||
|
AND position($5 IN "sign") -- bytes
|
||||||
|
AND "label" LIKE $6 -- string
|
||||||
|
AND "created_at" > $7; -- time.Time`
|
||||||
|
|
||||||
|
var benchmarkArgs = []any{
|
||||||
|
int64(12345),
|
||||||
|
nil,
|
||||||
|
float64(500),
|
||||||
|
true,
|
||||||
|
[]byte("8BADF00D"),
|
||||||
|
"kombucha's han'dy awokowa",
|
||||||
|
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkSanitize(b *testing.B) {
|
||||||
|
query, err := sanitize.NewQuery(benchmarkQuery)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to create query: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to sanitize query: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchmarkNewSQLResult string
|
||||||
|
|
||||||
|
func BenchmarkSanitizeSQL(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
var err error
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("failed to sanitize SQL: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,55 @@
|
||||||
|
package sanitize_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FuzzQuoteString(f *testing.F) {
|
||||||
|
const prefix = "prefix"
|
||||||
|
f.Add("new\nline")
|
||||||
|
f.Add("sample text")
|
||||||
|
f.Add("sample q'u'o't'e's")
|
||||||
|
f.Add("select 'quoted $42', $1")
|
||||||
|
|
||||||
|
f.Fuzz(func(t *testing.T, input string) {
|
||||||
|
got := string(sanitize.QuoteString([]byte(prefix), input))
|
||||||
|
want := oldQuoteString(input)
|
||||||
|
|
||||||
|
quoted, ok := strings.CutPrefix(got, prefix)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("result has no prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
if want != quoted {
|
||||||
|
t.Errorf("got %q", got)
|
||||||
|
t.Fatalf("want %q", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func FuzzQuoteBytes(f *testing.F) {
|
||||||
|
const prefix = "prefix"
|
||||||
|
f.Add([]byte(nil))
|
||||||
|
f.Add([]byte("\n"))
|
||||||
|
f.Add([]byte("sample text"))
|
||||||
|
f.Add([]byte("sample q'u'o't'e's"))
|
||||||
|
f.Add([]byte("select 'quoted $42', $1"))
|
||||||
|
|
||||||
|
f.Fuzz(func(t *testing.T, input []byte) {
|
||||||
|
got := string(sanitize.QuoteBytes([]byte(prefix), input))
|
||||||
|
want := oldQuoteBytes(input)
|
||||||
|
|
||||||
|
quoted, ok := strings.CutPrefix(got, prefix)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("result has no prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
if want != quoted {
|
||||||
|
t.Errorf("got %q", got)
|
||||||
|
t.Fatalf("want %q", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,6 +1,8 @@
|
||||||
package sanitize_test
|
package sanitize_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQuoteString(t *testing.T) {
|
||||||
|
tc := func(name, input string) {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := string(sanitize.QuoteString(nil, input))
|
||||||
|
want := oldQuoteString(input)
|
||||||
|
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("got: %s", got)
|
||||||
|
t.Fatalf("want: %s", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tc("empty", "")
|
||||||
|
tc("text", "abcd")
|
||||||
|
tc("with quotes", `one's hat is always a cat`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function was used before optimizations.
|
||||||
|
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||||
|
func oldQuoteString(str string) string {
|
||||||
|
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuoteBytes(t *testing.T) {
|
||||||
|
tc := func(name string, input []byte) {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := string(sanitize.QuoteBytes(nil, input))
|
||||||
|
want := oldQuoteBytes(input)
|
||||||
|
|
||||||
|
if got != want {
|
||||||
|
t.Errorf("got: %s", got)
|
||||||
|
t.Fatalf("want: %s", want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tc("nil", nil)
|
||||||
|
tc("empty", []byte{})
|
||||||
|
tc("text", []byte("abcd"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// This function was used before optimizations.
|
||||||
|
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||||
|
func oldQuoteBytes(buf []byte) string {
|
||||||
|
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||||
|
}
|
||||||
|
|
|
@ -51,6 +51,8 @@ type Config struct {
|
||||||
KerberosSpn string
|
KerberosSpn string
|
||||||
Fallbacks []*FallbackConfig
|
Fallbacks []*FallbackConfig
|
||||||
|
|
||||||
|
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||||
|
|
||||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||||
|
@ -318,6 +320,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||||
"sslkey": {},
|
"sslkey": {},
|
||||||
"sslcert": {},
|
"sslcert": {},
|
||||||
"sslrootcert": {},
|
"sslrootcert": {},
|
||||||
|
"sslnegotiation": {},
|
||||||
"sslpassword": {},
|
"sslpassword": {},
|
||||||
"sslsni": {},
|
"sslsni": {},
|
||||||
"krbspn": {},
|
"krbspn": {},
|
||||||
|
@ -386,6 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
||||||
config.Port = fallbacks[0].Port
|
config.Port = fallbacks[0].Port
|
||||||
config.TLSConfig = fallbacks[0].TLSConfig
|
config.TLSConfig = fallbacks[0].TLSConfig
|
||||||
config.Fallbacks = fallbacks[1:]
|
config.Fallbacks = fallbacks[1:]
|
||||||
|
config.SSLNegotiation = settings["sslnegotiation"]
|
||||||
|
|
||||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
@ -449,6 +453,7 @@ func parseEnvSettings() map[string]string {
|
||||||
"PGSSLSNI": "sslsni",
|
"PGSSLSNI": "sslsni",
|
||||||
"PGSSLROOTCERT": "sslrootcert",
|
"PGSSLROOTCERT": "sslrootcert",
|
||||||
"PGSSLPASSWORD": "sslpassword",
|
"PGSSLPASSWORD": "sslpassword",
|
||||||
|
"PGSSLNEGOTIATION": "sslnegotiation",
|
||||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||||
"PGSERVICE": "service",
|
"PGSERVICE": "service",
|
||||||
"PGSERVICEFILE": "servicefile",
|
"PGSERVICEFILE": "servicefile",
|
||||||
|
@ -646,6 +651,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||||
sslkey := settings["sslkey"]
|
sslkey := settings["sslkey"]
|
||||||
sslpassword := settings["sslpassword"]
|
sslpassword := settings["sslpassword"]
|
||||||
sslsni := settings["sslsni"]
|
sslsni := settings["sslsni"]
|
||||||
|
sslnegotiation := settings["sslnegotiation"]
|
||||||
|
|
||||||
// Match libpq default behavior
|
// Match libpq default behavior
|
||||||
if sslmode == "" {
|
if sslmode == "" {
|
||||||
|
@ -657,6 +663,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
||||||
|
|
||||||
tlsConfig := &tls.Config{}
|
tlsConfig := &tls.Config{}
|
||||||
|
|
||||||
|
if sslnegotiation == "direct" {
|
||||||
|
tlsConfig.NextProtos = []string{"postgresql"}
|
||||||
|
if sslmode == "prefer" {
|
||||||
|
sslmode = "require"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if sslrootcert != "" {
|
if sslrootcert != "" {
|
||||||
var caCertPool *x509.CertPool
|
var caCertPool *x509.CertPool
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,7 @@ func TestContextWatcherContextCancelled(t *testing.T) {
|
||||||
require.True(t, cleanupCalled, "Cleanup func was not called")
|
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
|
||||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||||
handleCancel: func(context.Context) {
|
handleCancel: func(context.Context) {
|
||||||
t.Error("cancel func should not have been called")
|
t.Error("cancel func should not have been called")
|
||||||
|
|
250
pgconn/pgconn.go
250
pgconn/pgconn.go
|
@ -1,6 +1,7 @@
|
||||||
package pgconn
|
package pgconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"container/list"
|
||||||
"context"
|
"context"
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
@ -267,12 +268,15 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*
|
||||||
|
|
||||||
var pgErr *PgError
|
var pgErr *PgError
|
||||||
if errors.As(err, &pgErr) {
|
if errors.As(err, &pgErr) {
|
||||||
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
|
// pgx will try next host even if libpq does not in certain cases (see #2246)
|
||||||
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
|
// consider change for the next major version
|
||||||
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
|
|
||||||
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
|
const ERRCODE_INVALID_PASSWORD = "28P01"
|
||||||
|
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
|
||||||
|
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
|
||||||
|
|
||||||
|
// auth failed due to invalid password, db does not exist or user has no permission
|
||||||
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
|
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
|
||||||
pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
|
|
||||||
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
|
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
|
||||||
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
|
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
|
||||||
return nil, allErrors
|
return nil, allErrors
|
||||||
|
@ -321,7 +325,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
|
||||||
if connectConfig.tlsConfig != nil {
|
if connectConfig.tlsConfig != nil {
|
||||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
|
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
|
var (
|
||||||
|
tlsConn net.Conn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if config.SSLNegotiation == "direct" {
|
||||||
|
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
|
||||||
|
} else {
|
||||||
|
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
|
||||||
|
}
|
||||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
|
@ -1408,9 +1420,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
||||||
|
|
||||||
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
||||||
type MultiResultReader struct {
|
type MultiResultReader struct {
|
||||||
pgConn *PgConn
|
pgConn *PgConn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
pipeline *Pipeline
|
|
||||||
|
|
||||||
rr *ResultReader
|
rr *ResultReader
|
||||||
|
|
||||||
|
@ -1443,12 +1454,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
mrr.closed = true
|
mrr.closed = true
|
||||||
if mrr.pipeline != nil {
|
mrr.pgConn.contextWatcher.Unwatch()
|
||||||
mrr.pipeline.expectedReadyForQueryCount--
|
mrr.pgConn.unlock()
|
||||||
} else {
|
|
||||||
mrr.pgConn.contextWatcher.Unwatch()
|
|
||||||
mrr.pgConn.unlock()
|
|
||||||
}
|
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
mrr.err = ErrorResponseToPgError(msg)
|
mrr.err = ErrorResponseToPgError(msg)
|
||||||
}
|
}
|
||||||
|
@ -1672,7 +1679,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
|
||||||
case *pgproto3.EmptyQueryResponse:
|
case *pgproto3.EmptyQueryResponse:
|
||||||
rr.concludeCommand(CommandTag{}, nil)
|
rr.concludeCommand(CommandTag{}, nil)
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
|
pgErr := ErrorResponseToPgError(msg)
|
||||||
|
if rr.pipeline != nil {
|
||||||
|
rr.pipeline.state.HandleError(pgErr)
|
||||||
|
}
|
||||||
|
rr.concludeCommand(CommandTag{}, pgErr)
|
||||||
}
|
}
|
||||||
|
|
||||||
return msg, nil
|
return msg, nil
|
||||||
|
@ -1773,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||||
|
|
||||||
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||||
if batch.err != nil {
|
if batch.err != nil {
|
||||||
|
pgConn.contextWatcher.Unwatch()
|
||||||
|
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = batch.err
|
pgConn.asyncClose()
|
||||||
pgConn.unlock()
|
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1783,9 +1795,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
||||||
defer pgConn.exitPotentialWriteReadDeadlock()
|
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||||
_, err := pgConn.conn.Write(batch.buf)
|
_, err := pgConn.conn.Write(batch.buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
pgConn.contextWatcher.Unwatch()
|
||||||
|
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
|
||||||
multiResult.closed = true
|
multiResult.closed = true
|
||||||
multiResult.err = err
|
pgConn.asyncClose()
|
||||||
pgConn.unlock()
|
|
||||||
return multiResult
|
return multiResult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1999,9 +2012,7 @@ type Pipeline struct {
|
||||||
conn *PgConn
|
conn *PgConn
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
|
|
||||||
expectedReadyForQueryCount int
|
state pipelineState
|
||||||
pendingSync bool
|
|
||||||
|
|
||||||
err error
|
err error
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
@ -2012,6 +2023,122 @@ type PipelineSync struct{}
|
||||||
// CloseComplete is returned by GetResults when a CloseComplete message is received.
|
// CloseComplete is returned by GetResults when a CloseComplete message is received.
|
||||||
type CloseComplete struct{}
|
type CloseComplete struct{}
|
||||||
|
|
||||||
|
type pipelineRequestType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
pipelineNil pipelineRequestType = iota
|
||||||
|
pipelinePrepare
|
||||||
|
pipelineQueryParams
|
||||||
|
pipelineQueryPrepared
|
||||||
|
pipelineDeallocate
|
||||||
|
pipelineSyncRequest
|
||||||
|
pipelineFlushRequest
|
||||||
|
)
|
||||||
|
|
||||||
|
type pipelineRequestEvent struct {
|
||||||
|
RequestType pipelineRequestType
|
||||||
|
WasSentToServer bool
|
||||||
|
BeforeFlushOrSync bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipelineState struct {
|
||||||
|
requestEventQueue list.List
|
||||||
|
lastRequestType pipelineRequestType
|
||||||
|
pgErr *PgError
|
||||||
|
expectedReadyForQueryCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) Init() {
|
||||||
|
s.requestEventQueue.Init()
|
||||||
|
s.lastRequestType = pipelineNil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) RegisterSendingToServer() {
|
||||||
|
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
|
||||||
|
val := elem.Value.(pipelineRequestEvent)
|
||||||
|
if val.WasSentToServer {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val.WasSentToServer = true
|
||||||
|
elem.Value = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) registerFlushingBufferOnServer() {
|
||||||
|
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
|
||||||
|
val := elem.Value.(pipelineRequestEvent)
|
||||||
|
if val.BeforeFlushOrSync {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
val.BeforeFlushOrSync = true
|
||||||
|
elem.Value = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
|
||||||
|
if req == pipelineNil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req != pipelineFlushRequest {
|
||||||
|
s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
|
||||||
|
}
|
||||||
|
if req == pipelineFlushRequest || req == pipelineSyncRequest {
|
||||||
|
s.registerFlushingBufferOnServer()
|
||||||
|
}
|
||||||
|
s.lastRequestType = req
|
||||||
|
|
||||||
|
if req == pipelineSyncRequest {
|
||||||
|
s.expectedReadyForQueryCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
|
||||||
|
for {
|
||||||
|
elem := s.requestEventQueue.Front()
|
||||||
|
if elem == nil {
|
||||||
|
return pipelineNil
|
||||||
|
}
|
||||||
|
val := elem.Value.(pipelineRequestEvent)
|
||||||
|
if !(val.WasSentToServer && val.BeforeFlushOrSync) {
|
||||||
|
return pipelineNil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.requestEventQueue.Remove(elem)
|
||||||
|
if val.RequestType == pipelineSyncRequest {
|
||||||
|
s.pgErr = nil
|
||||||
|
}
|
||||||
|
if s.pgErr == nil {
|
||||||
|
return val.RequestType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) HandleError(err *PgError) {
|
||||||
|
s.pgErr = err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) HandleReadyForQuery() {
|
||||||
|
s.expectedReadyForQueryCount--
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) PendingSync() bool {
|
||||||
|
var notPendingSync bool
|
||||||
|
|
||||||
|
if elem := s.requestEventQueue.Back(); elem != nil {
|
||||||
|
val := elem.Value.(pipelineRequestEvent)
|
||||||
|
notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer
|
||||||
|
} else {
|
||||||
|
notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return !notPendingSync
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *pipelineState) ExpectedReadyForQuery() int {
|
||||||
|
return s.expectedReadyForQueryCount
|
||||||
|
}
|
||||||
|
|
||||||
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
|
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
|
||||||
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
|
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
|
||||||
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
|
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
|
||||||
|
@ -2020,16 +2147,21 @@ type CloseComplete struct{}
|
||||||
// Prefer ExecBatch when only sending one group of queries at once.
|
// Prefer ExecBatch when only sending one group of queries at once.
|
||||||
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
|
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
|
||||||
if err := pgConn.lock(); err != nil {
|
if err := pgConn.lock(); err != nil {
|
||||||
return &Pipeline{
|
pipeline := &Pipeline{
|
||||||
closed: true,
|
closed: true,
|
||||||
err: err,
|
err: err,
|
||||||
}
|
}
|
||||||
|
pipeline.state.Init()
|
||||||
|
|
||||||
|
return pipeline
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.pipeline = Pipeline{
|
pgConn.pipeline = Pipeline{
|
||||||
conn: pgConn,
|
conn: pgConn,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
}
|
}
|
||||||
|
pgConn.pipeline.state.Init()
|
||||||
|
|
||||||
pipeline := &pgConn.pipeline
|
pipeline := &pgConn.pipeline
|
||||||
|
|
||||||
if ctx != context.Background() {
|
if ctx != context.Background() {
|
||||||
|
@ -2052,10 +2184,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.pendingSync = true
|
|
||||||
|
|
||||||
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
||||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
||||||
|
p.state.PushBackRequestType(pipelinePrepare)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendDeallocate deallocates a prepared statement.
|
// SendDeallocate deallocates a prepared statement.
|
||||||
|
@ -2063,9 +2195,9 @@ func (p *Pipeline) SendDeallocate(name string) {
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.pendingSync = true
|
|
||||||
|
|
||||||
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
|
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
|
||||||
|
p.state.PushBackRequestType(pipelineDeallocate)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
||||||
|
@ -2073,12 +2205,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.pendingSync = true
|
|
||||||
|
|
||||||
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
|
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
|
||||||
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||||
|
p.state.PushBackRequestType(pipelineQueryParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
||||||
|
@ -2086,11 +2218,42 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
p.pendingSync = true
|
|
||||||
|
|
||||||
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||||
|
p.state.PushBackRequestType(pipelineQueryPrepared)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendFlushRequest sends a request for the server to flush its output buffer.
|
||||||
|
//
|
||||||
|
// The server flushes its output buffer automatically as a result of Sync being called,
|
||||||
|
// or on any request when not in pipeline mode; this function is useful to cause the server
|
||||||
|
// to flush its output buffer in pipeline mode without establishing a synchronization point.
|
||||||
|
// Note that the request is not itself flushed to the server automatically; use Flush if
|
||||||
|
// necessary. This copies the behavior of libpq PQsendFlushRequest.
|
||||||
|
func (p *Pipeline) SendFlushRequest() {
|
||||||
|
if p.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.conn.frontend.Send(&pgproto3.Flush{})
|
||||||
|
p.state.PushBackRequestType(pipelineFlushRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
|
||||||
|
// without flushing the send buffer. This serves as the delimiter of an implicit
|
||||||
|
// transaction and an error recovery point.
|
||||||
|
//
|
||||||
|
// Note that the request is not itself flushed to the server automatically; use Flush if
|
||||||
|
// necessary. This copies the behavior of libpq PQsendPipelineSync.
|
||||||
|
func (p *Pipeline) SendPipelineSync() {
|
||||||
|
if p.closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
||||||
|
p.state.PushBackRequestType(pipelineSyncRequest)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush flushes the queued requests without establishing a synchronization point.
|
// Flush flushes the queued requests without establishing a synchronization point.
|
||||||
|
@ -2115,28 +2278,14 @@ func (p *Pipeline) Flush() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.state.RegisterSendingToServer()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync establishes a synchronization point and flushes the queued requests.
|
// Sync establishes a synchronization point and flushes the queued requests.
|
||||||
func (p *Pipeline) Sync() error {
|
func (p *Pipeline) Sync() error {
|
||||||
if p.closed {
|
p.SendPipelineSync()
|
||||||
if p.err != nil {
|
return p.Flush()
|
||||||
return p.err
|
|
||||||
}
|
|
||||||
return errors.New("pipeline closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
|
||||||
err := p.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
p.pendingSync = false
|
|
||||||
p.expectedReadyForQueryCount++
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
|
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
|
||||||
|
@ -2150,7 +2299,7 @@ func (p *Pipeline) GetResults() (results any, err error) {
|
||||||
return nil, errors.New("pipeline closed")
|
return nil, errors.New("pipeline closed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.expectedReadyForQueryCount == 0 {
|
if p.state.ExtractFrontRequestType() == pipelineNil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2195,13 +2344,13 @@ func (p *Pipeline) getResults() (results any, err error) {
|
||||||
case *pgproto3.CloseComplete:
|
case *pgproto3.CloseComplete:
|
||||||
return &CloseComplete{}, nil
|
return &CloseComplete{}, nil
|
||||||
case *pgproto3.ReadyForQuery:
|
case *pgproto3.ReadyForQuery:
|
||||||
p.expectedReadyForQueryCount--
|
p.state.HandleReadyForQuery()
|
||||||
return &PipelineSync{}, nil
|
return &PipelineSync{}, nil
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgErr := ErrorResponseToPgError(msg)
|
pgErr := ErrorResponseToPgError(msg)
|
||||||
|
p.state.HandleError(pgErr)
|
||||||
return nil, pgErr
|
return nil, pgErr
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2231,6 +2380,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
|
||||||
// These should never happen here. But don't take chances that could lead to a deadlock.
|
// These should never happen here. But don't take chances that could lead to a deadlock.
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgErr := ErrorResponseToPgError(msg)
|
pgErr := ErrorResponseToPgError(msg)
|
||||||
|
p.state.HandleError(pgErr)
|
||||||
return nil, pgErr
|
return nil, pgErr
|
||||||
case *pgproto3.CommandComplete:
|
case *pgproto3.CommandComplete:
|
||||||
p.conn.asyncClose()
|
p.conn.asyncClose()
|
||||||
|
@ -2250,7 +2400,7 @@ func (p *Pipeline) Close() error {
|
||||||
|
|
||||||
p.closed = true
|
p.closed = true
|
||||||
|
|
||||||
if p.pendingSync {
|
if p.state.PendingSync() {
|
||||||
p.conn.asyncClose()
|
p.conn.asyncClose()
|
||||||
p.err = errors.New("pipeline has unsynced requests")
|
p.err = errors.New("pipeline has unsynced requests")
|
||||||
p.conn.contextWatcher.Unwatch()
|
p.conn.contextWatcher.Unwatch()
|
||||||
|
@ -2259,7 +2409,7 @@ func (p *Pipeline) Close() error {
|
||||||
return p.err
|
return p.err
|
||||||
}
|
}
|
||||||
|
|
||||||
for p.expectedReadyForQueryCount > 0 {
|
for p.state.ExpectedReadyForQuery() > 0 {
|
||||||
_, err := p.getResults()
|
_, err := p.getResults()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.err = err
|
p.err = err
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -1420,6 +1421,52 @@ func TestConnExecBatch(t *testing.T) {
|
||||||
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockConnection struct {
|
||||||
|
net.Conn
|
||||||
|
writeLatency *time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockConnection) Write(b []byte) (n int, err error) {
|
||||||
|
time.Sleep(*m.writeLatency)
|
||||||
|
return m.Conn.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnExecBatchWriteError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var mockConn mockConnection
|
||||||
|
writeLatency := 0 * time.Second
|
||||||
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
conn, err := net.Dial(network, address)
|
||||||
|
mockConn = mockConnection{conn, &writeLatency}
|
||||||
|
return mockConn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
batch := &pgconn.Batch{}
|
||||||
|
pgConn.Conn()
|
||||||
|
|
||||||
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
|
||||||
|
defer cancel2()
|
||||||
|
|
||||||
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
||||||
|
writeLatency = 2 * time.Second
|
||||||
|
mrr := pgConn.ExecBatch(ctx2, batch)
|
||||||
|
err = mrr.Close()
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||||
|
require.True(t, pgConn.IsClosed())
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnExecBatchDeferredError(t *testing.T) {
|
func TestConnExecBatchDeferredError(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -3105,6 +3152,344 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPipelineFlushForSingleRequests(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
pipeline := pgConn.StartPipeline(ctx)
|
||||||
|
|
||||||
|
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err := pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
sd, ok := results.(*pgconn.StatementDescription)
|
||||||
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||||
|
require.Len(t, sd.Fields, 1)
|
||||||
|
require.Equal(t, "msg", string(sd.Fields[0].Name))
|
||||||
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok := results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult := rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendDeallocate("ps")
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = results.(*pgconn.CloseComplete)
|
||||||
|
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
err = pipeline.Sync()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = results.(*pgconn.PipelineSync)
|
||||||
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
err = pipeline.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPipelineFlushForRequestSeries(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
pipeline := pgConn.StartPipeline(ctx)
|
||||||
|
pipeline.SendPrepare("ps", "select $1::bigint as num", nil)
|
||||||
|
err = pipeline.Sync()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err := pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
sd, ok := results.(*pgconn.StatementDescription)
|
||||||
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||||
|
require.Len(t, sd.Fields, 1)
|
||||||
|
require.Equal(t, "num", string(sd.Fields[0].Name))
|
||||||
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = results.(*pgconn.PipelineSync)
|
||||||
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||||
|
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil)
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok := results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult := rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil)
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
err = pipeline.Sync()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = results.(*pgconn.PipelineSync)
|
||||||
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
err = pipeline.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPipelineFlushWithError(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
pipeline := pgConn.StartPipeline(ctx)
|
||||||
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err := pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok := results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult := rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
require.ErrorAs(t, readResult.Err, &pgErr)
|
||||||
|
require.Equal(t, "22012", pgErr.Code)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendPipelineSync()
|
||||||
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||||
|
pipeline.SendFlushRequest()
|
||||||
|
err = pipeline.Flush()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = results.(*pgconn.PipelineSync)
|
||||||
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rr, ok = results.(*pgconn.ResultReader)
|
||||||
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||||
|
readResult = rr.Read()
|
||||||
|
require.NoError(t, readResult.Err)
|
||||||
|
require.Len(t, readResult.Rows, 1)
|
||||||
|
require.Len(t, readResult.Rows[0], 1)
|
||||||
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, results)
|
||||||
|
|
||||||
|
err = pipeline.Sync()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
results, err = pipeline.GetResults()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, ok = results.(*pgconn.PipelineSync)
|
||||||
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||||
|
|
||||||
|
err = pipeline.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, pgConn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -3435,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connString string
|
||||||
|
expectDirectNego bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Default negotiation (postgres)",
|
||||||
|
connString: "sslmode=require",
|
||||||
|
expectDirectNego: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Direct negotiation",
|
||||||
|
connString: "sslmode=require sslnegotiation=direct",
|
||||||
|
expectDirectNego: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Explicit postgres negotiation",
|
||||||
|
connString: "sslmode=require sslnegotiation=postgres",
|
||||||
|
expectDirectNego: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
script := &pgmock.Script{
|
||||||
|
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var directNegoObserved atomic.Bool
|
||||||
|
|
||||||
|
serverErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
defer close(serverErrCh)
|
||||||
|
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("accept error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
|
||||||
|
firstByte := make([]byte, 1)
|
||||||
|
_, err = conn.Read(firstByte)
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
|
||||||
|
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
|
||||||
|
directNegoObserved.Store(isDirect)
|
||||||
|
|
||||||
|
var tlsConn *tls.Conn
|
||||||
|
|
||||||
|
if !isDirect {
|
||||||
|
// Handle standard PostgreSQL SSL negotiation
|
||||||
|
// Read the rest of the SSL request message
|
||||||
|
sslRequestRemainder := make([]byte, 7)
|
||||||
|
_, err = io.ReadFull(conn, sslRequestRemainder)
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send SSL acceptance response
|
||||||
|
_, err = conn.Write([]byte("S"))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup TLS server without needing to reuse the first byte
|
||||||
|
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn = tls.Server(conn, &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// Handle direct TLS negotiation
|
||||||
|
// Setup TLS server with the first byte already read
|
||||||
|
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use a wrapper to inject the first byte back into the TLS handshake
|
||||||
|
bufConn := &prefixConn{
|
||||||
|
Conn: conn,
|
||||||
|
prefixData: firstByte,
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn = tls.Server(bufConn, &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete TLS handshake
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer tlsConn.Close()
|
||||||
|
|
||||||
|
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
|
||||||
|
if err != nil {
|
||||||
|
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
|
||||||
|
tt.connString, port)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
conn, err := pgconn.Connect(ctx, connStr)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
defer conn.Close(ctx)
|
||||||
|
|
||||||
|
err = <-serverErrCh
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixConn implements a net.Conn that prepends some data to the first Read
|
||||||
|
type prefixConn struct {
|
||||||
|
net.Conn
|
||||||
|
prefixData []byte
|
||||||
|
prefixConsumed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *prefixConn) Read(b []byte) (n int, err error) {
|
||||||
|
if !c.prefixConsumed && len(c.prefixData) > 0 {
|
||||||
|
n = copy(b, c.prefixData)
|
||||||
|
c.prefixData = c.prefixData[n:]
|
||||||
|
c.prefixConsumed = len(c.prefixData) == 0
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return c.Conn.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/1920
|
// https://github.com/jackc/pgx/issues/1920
|
||||||
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
|
@ -12,7 +12,7 @@ type PasswordMessage struct {
|
||||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
func (*PasswordMessage) Frontend() {}
|
func (*PasswordMessage) Frontend() {}
|
||||||
|
|
||||||
// Frontend identifies this message as an authentication response.
|
// InitialResponse identifies this message as an authentication response.
|
||||||
func (*PasswordMessage) InitialResponse() {}
|
func (*PasswordMessage) InitialResponse() {}
|
||||||
|
|
||||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
|
|
|
@ -11,8 +11,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCompositeCodecTranscode(t *testing.T) {
|
func TestCompositeCodecTranscode(t *testing.T) {
|
||||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists ct_test;
|
_, err := conn.Exec(ctx, `drop type if exists ct_test;
|
||||||
|
@ -91,8 +89,6 @@ func (p *point3d) ScanIndex(i int) any {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompositeCodecTranscodeStruct(t *testing.T) {
|
func TestCompositeCodecTranscodeStruct(t *testing.T) {
|
||||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||||
|
@ -128,8 +124,6 @@ create type point3d as (
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
||||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||||
|
@ -169,8 +163,6 @@ create type point3d as (
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompositeCodecDecodeValue(t *testing.T) {
|
func TestCompositeCodecDecodeValue(t *testing.T) {
|
||||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||||
|
@ -214,7 +206,7 @@ create type point3d as (
|
||||||
//
|
//
|
||||||
// https://github.com/jackc/pgx/issues/1576
|
// https://github.com/jackc/pgx/issues/1576
|
||||||
func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
|
func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
|
||||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
skipCockroachDB(t, "Server does not support composite types from table definitions")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
|
|
|
@ -53,8 +53,8 @@ similar fashion to database/sql. The second is to use a pointer to a pointer.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
When using nullable pgtype types as parameters for queries, one has to remember
|
When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to
|
||||||
to explicitly set their Valid field to true, otherwise the parameter's value will be NULL.
|
true, otherwise the parameter's value will be NULL.
|
||||||
|
|
||||||
JSON Support
|
JSON Support
|
||||||
|
|
||||||
|
@ -159,11 +159,16 @@ example_child_records_test.go for an example.
|
||||||
|
|
||||||
Overview of Scanning Implementation
|
Overview of Scanning Implementation
|
||||||
|
|
||||||
The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID
|
The first step is to use the OID to lookup the correct Codec. The Map will call the Codec's PlanScan method to get a
|
||||||
from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for
|
plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types
|
||||||
scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are
|
are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner
|
||||||
interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and
|
and PointValuer interfaces.
|
||||||
PointValuer interfaces.
|
|
||||||
|
If a Go value is not supported directly by a Codec then Map will try see if it is a sql.Scanner. If is then that
|
||||||
|
interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and
|
||||||
|
numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be
|
||||||
|
parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with
|
||||||
|
a large number of affected values.
|
||||||
|
|
||||||
If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again.
|
If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again.
|
||||||
For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that
|
For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// Do not edit. Generated from pgtype/int.go.erb
|
// Code generated from pgtype/int.go.erb. DO NOT EDIT.
|
||||||
|
|
||||||
package pgtype
|
package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// Do not edit. Generated from pgtype/int_test.go.erb
|
// Code generated from pgtype/int_test.go.erb. DO NOT EDIT.
|
||||||
|
|
||||||
package pgtype_test
|
package pgtype_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT.
|
||||||
|
|
||||||
package pgtype_test
|
package pgtype_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
104
pgtype/json.go
104
pgtype/json.go
|
@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JSON needs its on scan plan for pointers to handle 'null'::json(b).
|
||||||
|
// Consider making pointerPointerScanPlan more flexible in the future.
|
||||||
|
type jsonPointerScanPlan struct {
|
||||||
|
next ScanPlan
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p jsonPointerScanPlan) Scan(src []byte, dst any) error {
|
||||||
|
el := reflect.ValueOf(dst).Elem()
|
||||||
|
if src == nil || string(src) == "null" {
|
||||||
|
el.SetZero()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
el.Set(reflect.New(el.Type().Elem()))
|
||||||
|
if p.next != nil {
|
||||||
|
return p.next.Scan(src, el.Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type encodePlanJSONCodecEitherFormatString struct{}
|
type encodePlanJSONCodecEitherFormatString struct{}
|
||||||
|
|
||||||
func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||||
|
@ -117,58 +138,36 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (
|
||||||
return buf, nil
|
return buf, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan {
|
||||||
switch target.(type) {
|
return c.planScan(m, oid, formatCode, target, 0)
|
||||||
case *string:
|
|
||||||
return scanPlanAnyToString{}
|
|
||||||
|
|
||||||
case **string:
|
|
||||||
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
|
|
||||||
// solution would be.
|
|
||||||
//
|
|
||||||
// https://github.com/jackc/pgx/issues/1470 -- **string
|
|
||||||
// https://github.com/jackc/pgx/issues/1691 -- ** anything else
|
|
||||||
|
|
||||||
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
|
|
||||||
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
|
|
||||||
if _, failed := nextPlan.(*scanPlanFail); !failed {
|
|
||||||
wrapperPlan.SetNext(nextPlan)
|
|
||||||
return wrapperPlan
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
case *[]byte:
|
|
||||||
return scanPlanJSONToByteSlice{}
|
|
||||||
case BytesScanner:
|
|
||||||
return scanPlanBinaryBytesToBytesScanner{}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
|
|
||||||
//
|
|
||||||
// https://github.com/jackc/pgx/issues/1418
|
|
||||||
if isSQLScanner(target) {
|
|
||||||
return &scanPlanSQLScanner{formatCode: format}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &scanPlanJSONToJSONUnmarshal{
|
|
||||||
unmarshal: c.Unmarshal,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
|
// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b),
|
||||||
//
|
// so we need to duplicate the logic here.
|
||||||
// https://github.com/jackc/pgx/issues/2146
|
func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan {
|
||||||
func isSQLScanner(v any) bool {
|
if depth > 8 {
|
||||||
val := reflect.ValueOf(v)
|
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
|
||||||
for val.Kind() == reflect.Ptr {
|
}
|
||||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
|
||||||
return true
|
switch target.(type) {
|
||||||
}
|
case *string:
|
||||||
val = val.Elem()
|
return &scanPlanAnyToString{}
|
||||||
|
case *[]byte:
|
||||||
|
return &scanPlanJSONToByteSlice{}
|
||||||
|
case BytesScanner:
|
||||||
|
return &scanPlanBinaryBytesToBytesScanner{}
|
||||||
|
case sql.Scanner:
|
||||||
|
return &scanPlanSQLScanner{formatCode: formatCode}
|
||||||
|
}
|
||||||
|
|
||||||
|
rv := reflect.ValueOf(target)
|
||||||
|
if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer {
|
||||||
|
var plan jsonPointerScanPlan
|
||||||
|
plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1)
|
||||||
|
return plan
|
||||||
|
} else {
|
||||||
|
return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal}
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type scanPlanAnyToString struct{}
|
type scanPlanAnyToString struct{}
|
||||||
|
@ -212,7 +211,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
||||||
return fmt.Errorf("cannot scan NULL into %T", dst)
|
return fmt.Errorf("cannot scan NULL into %T", dst)
|
||||||
}
|
}
|
||||||
|
|
||||||
elem := reflect.ValueOf(dst).Elem()
|
v := reflect.ValueOf(dst)
|
||||||
|
if v.Kind() != reflect.Pointer || v.IsNil() {
|
||||||
|
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
|
||||||
|
}
|
||||||
|
|
||||||
|
elem := v.Elem()
|
||||||
elem.Set(reflect.Zero(elem.Type()))
|
elem.Set(reflect.Zero(elem.Type()))
|
||||||
|
|
||||||
return s.unmarshal(src, dst)
|
return s.unmarshal(src, dst)
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -48,6 +49,7 @@ func TestJSONCodec(t *testing.T) {
|
||||||
Age int `json:"age"`
|
Age int `json:"age"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var str string
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
|
||||||
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
|
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
|
||||||
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
|
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
|
||||||
|
@ -65,6 +67,9 @@ func TestJSONCodec(t *testing.T) {
|
||||||
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
||||||
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
|
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
|
||||||
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
|
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
|
||||||
|
|
||||||
|
// Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
|
||||||
|
{NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
|
||||||
})
|
})
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -136,6 +141,27 @@ func (i Issue2146) Value() (driver.Value, error) {
|
||||||
return string(b), err
|
return string(b), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NonPointerJSONScanner struct {
|
||||||
|
V *string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i NonPointerJSONScanner) Scan(src any) error {
|
||||||
|
switch c := src.(type) {
|
||||||
|
case string:
|
||||||
|
*i.V = c
|
||||||
|
case []byte:
|
||||||
|
*i.V = string(c)
|
||||||
|
default:
|
||||||
|
return errors.New("unknown source type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i NonPointerJSONScanner) Value() (driver.Value, error) {
|
||||||
|
return i.V, nil
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
||||||
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
@ -166,11 +192,15 @@ func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
// A string cannot scan a NULL.
|
// A string cannot scan a NULL.
|
||||||
str := "foobar"
|
str := "foobar"
|
||||||
err = conn.QueryRow(ctx, "select null::json").Scan(&str)
|
err = conn.QueryRow(ctx, "select null::json").Scan(&str)
|
||||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
fieldName := "json"
|
||||||
|
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||||
|
fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb.
|
||||||
|
}
|
||||||
|
require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *string", fieldName))
|
||||||
|
|
||||||
// A non-string cannot scan a NULL.
|
// A non-string cannot scan a NULL.
|
||||||
err = conn.QueryRow(ctx, "select null::json").Scan(&n)
|
err = conn.QueryRow(ctx, "select null::json").Scan(&n)
|
||||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int")
|
require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *int", fieldName))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -267,7 +297,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
Unmarshal: func(data []byte, v any) error {
|
Unmarshal: func(data []byte, v any) error {
|
||||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -278,3 +309,54 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
n := 44
|
||||||
|
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var i *int
|
||||||
|
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
m := 0
|
||||||
|
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 42, m)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONCodecScanNull(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
var dest struct{}
|
||||||
|
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")
|
||||||
|
|
||||||
|
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var destPointer *struct{}
|
||||||
|
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, destPointer)
|
||||||
|
|
||||||
|
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, destPointer)
|
||||||
|
|
||||||
|
var raw json.RawMessage
|
||||||
|
require.NoError(t, conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&raw))
|
||||||
|
require.Equal(t, json.RawMessage("null"), raw)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) {
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
var dest *Issue2146
|
||||||
|
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Nil(t, dest)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -66,11 +66,11 @@ func TestJSONBCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
// A string cannot scan a NULL.
|
// A string cannot scan a NULL.
|
||||||
str := "foobar"
|
str := "foobar"
|
||||||
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&str)
|
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&str)
|
||||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *string")
|
||||||
|
|
||||||
// A non-string cannot scan a NULL.
|
// A non-string cannot scan a NULL.
|
||||||
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&n)
|
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&n)
|
||||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int")
|
require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *int")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -396,11 +396,7 @@ type scanPlanSQLScanner struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||||
scanner := getSQLScanner(dst)
|
scanner := dst.(sql.Scanner)
|
||||||
|
|
||||||
if scanner == nil {
|
|
||||||
return fmt.Errorf("cannot scan into %T", dst)
|
|
||||||
}
|
|
||||||
|
|
||||||
if src == nil {
|
if src == nil {
|
||||||
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
|
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
|
||||||
|
@ -413,21 +409,6 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
|
||||||
func getSQLScanner(target any) sql.Scanner {
|
|
||||||
val := reflect.ValueOf(target)
|
|
||||||
for val.Kind() == reflect.Ptr {
|
|
||||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
|
||||||
if val.IsNil() {
|
|
||||||
val.Set(reflect.New(val.Type().Elem()))
|
|
||||||
}
|
|
||||||
return val.Interface().(sql.Scanner)
|
|
||||||
}
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type scanPlanString struct{}
|
type scanPlanString struct{}
|
||||||
|
|
||||||
func (scanPlanString) Scan(src []byte, dst any) error {
|
func (scanPlanString) Scan(src []byte, dst any) error {
|
||||||
|
|
|
@ -91,7 +91,25 @@ func initDefaultMap() {
|
||||||
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}})
|
defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}})
|
||||||
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}})
|
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{
|
||||||
|
Marshal: xml.Marshal,
|
||||||
|
// xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a
|
||||||
|
// *any. Wrap xml.Marshal with a function that copies the data into a new byte slice in this case. Not implementing
|
||||||
|
// directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML
|
||||||
|
// unmarshaler that does support unmarshalling into *any.
|
||||||
|
//
|
||||||
|
// https://github.com/jackc/pgx/issues/2227
|
||||||
|
// https://github.com/jackc/pgx/pull/2228
|
||||||
|
Unmarshal: func(data []byte, v any) error {
|
||||||
|
if v, ok := v.(*any); ok {
|
||||||
|
dstBuf := make([]byte, len(data))
|
||||||
|
copy(dstBuf, data)
|
||||||
|
*v = dstBuf
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return xml.Unmarshal(data, v)
|
||||||
|
},
|
||||||
|
}})
|
||||||
|
|
||||||
// Range types
|
// Range types
|
||||||
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})
|
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
|
const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
|
||||||
|
const jsonISO8601 = "2006-01-02T15:04:05.999999999"
|
||||||
|
|
||||||
type TimestampScanner interface {
|
type TimestampScanner interface {
|
||||||
ScanTimestamp(v Timestamp) error
|
ScanTimestamp(v Timestamp) error
|
||||||
|
@ -76,7 +77,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) {
|
||||||
|
|
||||||
switch ts.InfinityModifier {
|
switch ts.InfinityModifier {
|
||||||
case Finite:
|
case Finite:
|
||||||
s = ts.Time.Format(time.RFC3339Nano)
|
s = ts.Time.Format(jsonISO8601)
|
||||||
case Infinity:
|
case Infinity:
|
||||||
s = "infinity"
|
s = "infinity"
|
||||||
case NegativeInfinity:
|
case NegativeInfinity:
|
||||||
|
@ -104,15 +105,23 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error {
|
||||||
case "-infinity":
|
case "-infinity":
|
||||||
*ts = Timestamp{Valid: true, InfinityModifier: -Infinity}
|
*ts = Timestamp{Valid: true, InfinityModifier: -Infinity}
|
||||||
default:
|
default:
|
||||||
// PostgreSQL uses ISO 8601 wihout timezone for to_json function and casting from a string to timestampt
|
// Parse time with or without timezonr
|
||||||
tim, err := time.Parse(time.RFC3339Nano, *s+"Z")
|
tss := *s
|
||||||
if err != nil {
|
// PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestampt
|
||||||
return err
|
tim, err := time.Parse(time.RFC3339Nano, tss)
|
||||||
|
if err == nil {
|
||||||
|
*ts = Timestamp{Time: tim, Valid: true}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC)
|
||||||
*ts = Timestamp{Time: tim, Valid: true}
|
if err == nil {
|
||||||
|
*ts = Timestamp{Time: tim, Valid: true}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
ts.Valid = false
|
||||||
|
return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)",
|
||||||
|
*s, time.RFC3339Nano, jsonISO8601, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,12 +2,14 @@ package pgtype_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
pgx "github.com/jackc/pgx/v5"
|
pgx "github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
"github.com/jackc/pgx/v5/pgxtest"
|
"github.com/jackc/pgx/v5/pgxtest"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -100,13 +102,24 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimestampMarshalJSON(t *testing.T) {
|
func TestTimestampMarshalJSON(t *testing.T) {
|
||||||
|
|
||||||
|
tsStruct := struct {
|
||||||
|
TS pgtype.Timestamp `json:"ts"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
tm := time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC)
|
||||||
|
tsString := "\"" + tm.Format("2006-01-02T15:04:05") + "\"" // `"2012-03-29T10:05:45"`
|
||||||
|
var pgt pgtype.Timestamp
|
||||||
|
_ = pgt.Scan(tm)
|
||||||
|
|
||||||
successfulTests := []struct {
|
successfulTests := []struct {
|
||||||
source pgtype.Timestamp
|
source pgtype.Timestamp
|
||||||
result string
|
result string
|
||||||
}{
|
}{
|
||||||
{source: pgtype.Timestamp{}, result: "null"},
|
{source: pgtype.Timestamp{}, result: "null"},
|
||||||
{source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45Z\""},
|
{source: pgtype.Timestamp{Time: tm, Valid: true}, result: tsString},
|
||||||
{source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45.555Z\""},
|
{source: pgt, result: tsString},
|
||||||
|
{source: pgtype.Timestamp{Time: tm.Add(time.Second * 555 / 1000), Valid: true}, result: `"2012-03-29T10:05:45.555"`},
|
||||||
{source: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""},
|
{source: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""},
|
||||||
{source: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""},
|
{source: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""},
|
||||||
}
|
}
|
||||||
|
@ -116,12 +129,32 @@ func TestTimestampMarshalJSON(t *testing.T) {
|
||||||
t.Errorf("%d: %v", i, err)
|
t.Errorf("%d: %v", i, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if string(r) != tt.result {
|
if !assert.Equal(t, tt.result, string(r)) {
|
||||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r))
|
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r))
|
||||||
}
|
}
|
||||||
|
tsStruct.TS = tt.source
|
||||||
|
b, err := json.Marshal(tsStruct)
|
||||||
|
assert.NoErrorf(t, err, "failed to marshal %v %s", tt.source, err)
|
||||||
|
t2 := tsStruct
|
||||||
|
t2.TS = pgtype.Timestamp{} // Clear out the value so that we can compare after unmarshalling
|
||||||
|
err = json.Unmarshal(b, &t2)
|
||||||
|
assert.NoErrorf(t, err, "failed to unmarshal %v with %s", tt.source, err)
|
||||||
|
assert.True(t, tsStruct.TS.Time.Unix() == t2.TS.Time.Unix())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTimestampUnmarshalJSONErrors(t *testing.T) {
|
||||||
|
tsStruct := struct {
|
||||||
|
TS pgtype.Timestamp `json:"ts"`
|
||||||
|
}{}
|
||||||
|
goodJson1 := []byte(`{"ts":"2012-03-29T10:05:45"}`)
|
||||||
|
assert.NoError(t, json.Unmarshal(goodJson1, &tsStruct))
|
||||||
|
goodJson2 := []byte(`{"ts":"2012-03-29T10:05:45Z"}`)
|
||||||
|
assert.NoError(t, json.Unmarshal(goodJson2, &tsStruct))
|
||||||
|
badJson := []byte(`{"ts":"2012-03-29"}`)
|
||||||
|
assert.Error(t, json.Unmarshal(badJson, &tsStruct))
|
||||||
|
}
|
||||||
|
|
||||||
func TestTimestampUnmarshalJSON(t *testing.T) {
|
func TestTimestampUnmarshalJSON(t *testing.T) {
|
||||||
successfulTests := []struct {
|
successfulTests := []struct {
|
||||||
source string
|
source string
|
||||||
|
|
|
@ -79,7 +79,7 @@ func TestXMLCodecUnmarshalSQLNull(t *testing.T) {
|
||||||
// A string cannot scan a NULL.
|
// A string cannot scan a NULL.
|
||||||
str := "foobar"
|
str := "foobar"
|
||||||
err = conn.QueryRow(ctx, "select null::xml").Scan(&str)
|
err = conn.QueryRow(ctx, "select null::xml").Scan(&str)
|
||||||
assert.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
assert.EqualError(t, err, "can't scan into dest[0] (col: xml): cannot scan NULL into *string")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,3 +97,32 @@ func TestXMLCodecPointerToPointerToString(t *testing.T) {
|
||||||
require.Nil(t, s)
|
require.Nil(t, s)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestXMLCodecDecodeValue(t *testing.T) {
|
||||||
|
skipCockroachDB(t, "CockroachDB does not support XML.")
|
||||||
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
|
||||||
|
for _, tt := range []struct {
|
||||||
|
sql string
|
||||||
|
expected any
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sql: `select '<foo>bar</foo>'::xml`,
|
||||||
|
expected: []byte("<foo>bar</foo>"),
|
||||||
|
},
|
||||||
|
} {
|
||||||
|
t.Run(tt.sql, func(t *testing.T) {
|
||||||
|
rows, err := conn.Query(ctx, tt.sql)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
values, err := rows.Values()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, values, 1)
|
||||||
|
require.Equal(t, tt.expected, values[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// Do not edit. Generated from pgtype/zeronull/int.go.erb
|
// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT.
|
||||||
|
|
||||||
package zeronull
|
package zeronull
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
// Do not edit. Generated from pgtype/zeronull/int_test.go.erb
|
// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT.
|
||||||
|
|
||||||
package zeronull_test
|
package zeronull_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|
|
@ -147,6 +147,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName
|
||||||
assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
|
assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
|
||||||
assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
|
assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
|
||||||
assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
|
assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
|
||||||
|
assert.Equalf(t, expected.MinIdleConns, actual.MinIdleConns, "%s - MinIdleConns", testName)
|
||||||
assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
|
assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
|
||||||
|
|
||||||
assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)
|
assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
|
|
||||||
var defaultMaxConns = int32(4)
|
var defaultMaxConns = int32(4)
|
||||||
var defaultMinConns = int32(0)
|
var defaultMinConns = int32(0)
|
||||||
|
var defaultMinIdleConns = int32(0)
|
||||||
var defaultMaxConnLifetime = time.Hour
|
var defaultMaxConnLifetime = time.Hour
|
||||||
var defaultMaxConnIdleTime = time.Minute * 30
|
var defaultMaxConnIdleTime = time.Minute * 30
|
||||||
var defaultHealthCheckPeriod = time.Minute
|
var defaultHealthCheckPeriod = time.Minute
|
||||||
|
@ -87,6 +88,7 @@ type Pool struct {
|
||||||
afterRelease func(*pgx.Conn) bool
|
afterRelease func(*pgx.Conn) bool
|
||||||
beforeClose func(*pgx.Conn)
|
beforeClose func(*pgx.Conn)
|
||||||
minConns int32
|
minConns int32
|
||||||
|
minIdleConns int32
|
||||||
maxConns int32
|
maxConns int32
|
||||||
maxConnLifetime time.Duration
|
maxConnLifetime time.Duration
|
||||||
maxConnLifetimeJitter time.Duration
|
maxConnLifetimeJitter time.Duration
|
||||||
|
@ -144,6 +146,13 @@ type Config struct {
|
||||||
// to create new connections.
|
// to create new connections.
|
||||||
MinConns int32
|
MinConns int32
|
||||||
|
|
||||||
|
// MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that
|
||||||
|
// there are always idle connections available. This can help reduce tail latencies during request processing,
|
||||||
|
// as you can avoid the latency of establishing a new connection while handling requests. It is superior
|
||||||
|
// to MinConns for this purpose.
|
||||||
|
// Similar to MinConns, the pool might temporarily dip below MinIdleConns after connection closes.
|
||||||
|
MinIdleConns int32
|
||||||
|
|
||||||
// HealthCheckPeriod is the duration between checks of the health of idle connections.
|
// HealthCheckPeriod is the duration between checks of the health of idle connections.
|
||||||
HealthCheckPeriod time.Duration
|
HealthCheckPeriod time.Duration
|
||||||
|
|
||||||
|
@ -189,6 +198,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
|
||||||
afterRelease: config.AfterRelease,
|
afterRelease: config.AfterRelease,
|
||||||
beforeClose: config.BeforeClose,
|
beforeClose: config.BeforeClose,
|
||||||
minConns: config.MinConns,
|
minConns: config.MinConns,
|
||||||
|
minIdleConns: config.MinIdleConns,
|
||||||
maxConns: config.MaxConns,
|
maxConns: config.MaxConns,
|
||||||
maxConnLifetime: config.MaxConnLifetime,
|
maxConnLifetime: config.MaxConnLifetime,
|
||||||
maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
|
maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
|
||||||
|
@ -271,7 +281,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
p.createIdleResources(ctx, int(p.minConns))
|
targetIdleResources := max(int(p.minConns), int(p.minIdleConns))
|
||||||
|
p.createIdleResources(ctx, targetIdleResources)
|
||||||
p.backgroundHealthCheck()
|
p.backgroundHealthCheck()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -334,6 +345,17 @@ func ParseConfig(connString string) (*Config, error) {
|
||||||
config.MinConns = defaultMinConns
|
config.MinConns = defaultMinConns
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok {
|
||||||
|
delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns")
|
||||||
|
n, err := strconv.ParseInt(s, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot parse pool_min_idle_conns: %w", err)
|
||||||
|
}
|
||||||
|
config.MinIdleConns = int32(n)
|
||||||
|
} else {
|
||||||
|
config.MinIdleConns = defaultMinIdleConns
|
||||||
|
}
|
||||||
|
|
||||||
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
|
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
|
||||||
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
|
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
|
||||||
d, err := time.ParseDuration(s)
|
d, err := time.ParseDuration(s)
|
||||||
|
@ -472,7 +494,9 @@ func (p *Pool) checkMinConns() error {
|
||||||
// TotalConns can include ones that are being destroyed but we should have
|
// TotalConns can include ones that are being destroyed but we should have
|
||||||
// sleep(500ms) around all of the destroys to help prevent that from throwing
|
// sleep(500ms) around all of the destroys to help prevent that from throwing
|
||||||
// off this check
|
// off this check
|
||||||
toCreate := p.minConns - p.Stat().TotalConns()
|
|
||||||
|
// Create the number of connections needed to get to both minConns and minIdleConns
|
||||||
|
toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns())
|
||||||
if toCreate > 0 {
|
if toCreate > 0 {
|
||||||
return p.createIdleResources(context.Background(), int(toCreate))
|
return p.createIdleResources(context.Background(), int(toCreate))
|
||||||
}
|
}
|
||||||
|
|
|
@ -43,10 +43,11 @@ func TestConnectConfig(t *testing.T) {
|
||||||
func TestParseConfigExtractsPoolArguments(t *testing.T) {
|
func TestParseConfigExtractsPoolArguments(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1")
|
config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1 pool_min_idle_conns=2")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.EqualValues(t, 42, config.MaxConns)
|
assert.EqualValues(t, 42, config.MaxConns)
|
||||||
assert.EqualValues(t, 1, config.MinConns)
|
assert.EqualValues(t, 1, config.MinConns)
|
||||||
|
assert.EqualValues(t, 2, config.MinIdleConns)
|
||||||
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns")
|
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns")
|
||||||
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns")
|
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns")
|
||||||
}
|
}
|
||||||
|
|
|
@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 {
|
||||||
func (s *Stat) MaxIdleDestroyCount() int64 {
|
func (s *Stat) MaxIdleDestroyCount() int64 {
|
||||||
return s.idleDestroyCount
|
return s.idleDestroyCount
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires
|
||||||
|
// from the pool for a resource to be released or constructed because the pool was
|
||||||
|
// empty.
|
||||||
|
func (s *Stat) EmptyAcquireWaitTime() time.Duration {
|
||||||
|
return s.s.EmptyAcquireWaitTime()
|
||||||
|
}
|
||||||
|
|
|
@ -420,7 +420,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
|
||||||
t.Fatal("Expected Rows to have an error after an improper read but it didn't")
|
t.Fatal("Expected Rows to have an error after an improper read but it didn't")
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows.Err().Error() != "can't scan into dest[0]: cannot scan int4 (OID 23) in binary format into *time.Time" {
|
if rows.Err().Error() != "can't scan into dest[0] (col: n): cannot scan int4 (OID 23) in binary format into *time.Time" {
|
||||||
t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
|
t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
13
rows.go
13
rows.go
|
@ -272,7 +272,7 @@ func (rows *baseRows) Scan(dest ...any) error {
|
||||||
|
|
||||||
err := rows.scanPlans[i].Scan(values[i], dst)
|
err := rows.scanPlans[i].Scan(values[i], dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = ScanArgError{ColumnIndex: i, Err: err}
|
err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err}
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -334,11 +334,16 @@ func (rows *baseRows) Conn() *Conn {
|
||||||
|
|
||||||
type ScanArgError struct {
|
type ScanArgError struct {
|
||||||
ColumnIndex int
|
ColumnIndex int
|
||||||
|
FieldName string
|
||||||
Err error
|
Err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e ScanArgError) Error() string {
|
func (e ScanArgError) Error() string {
|
||||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
|
if e.FieldName == "?column?" { // Don't include the fieldname if it's unknown
|
||||||
|
return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e ScanArgError) Unwrap() error {
|
func (e ScanArgError) Unwrap() error {
|
||||||
|
@ -366,7 +371,7 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v
|
||||||
|
|
||||||
err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
|
err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ScanArgError{ColumnIndex: i, Err: err}
|
return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -468,6 +473,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
||||||
return value, err
|
return value, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The defer rows.Close() won't have executed yet. If the query returned more than one row, rows would still be open.
|
||||||
|
// rows.Close() must be called before rows.Err() so we explicitly call it here.
|
||||||
rows.Close()
|
rows.Close()
|
||||||
return value, rows.Err()
|
return value, rows.Err()
|
||||||
}
|
}
|
||||||
|
|
5
tx.go
5
tx.go
|
@ -3,7 +3,6 @@ package pgx
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -103,7 +102,7 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// begin should never fail unless there is an underlying connection issue or
|
// begin should never fail unless there is an underlying connection issue or
|
||||||
// a context timeout. In either case, the connection is possibly broken.
|
// a context timeout. In either case, the connection is possibly broken.
|
||||||
c.die(errors.New("failed to begin transaction"))
|
c.die()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error {
|
||||||
tx.closed = true
|
tx.closed = true
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// A rollback failure leaves the connection in an undefined state
|
// A rollback failure leaves the connection in an undefined state
|
||||||
tx.conn.die(fmt.Errorf("rollback failed: %w", err))
|
tx.conn.die()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package pgx_test
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
@ -215,7 +216,12 @@ func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typena
|
||||||
input := []int{1, 2, 234432}
|
input := []int{1, 2, 234432}
|
||||||
var output []int16
|
var output []int16
|
||||||
err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
|
err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
|
||||||
if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
|
fieldName := typename
|
||||||
|
if conn.PgConn().ParameterStatus("crdb_version") != "" && typename == "json" {
|
||||||
|
fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb.
|
||||||
|
}
|
||||||
|
expectedMessage := fmt.Sprintf("can't scan into dest[0] (col: %s): json: cannot unmarshal number 234432 into Go value of type int16", fieldName)
|
||||||
|
if err == nil || err.Error() != expectedMessage {
|
||||||
t.Errorf("%s: Expected *json.UnmarshalTypeError, but got %v", typename, err)
|
t.Errorf("%s: Expected *json.UnmarshalTypeError, but got %v", typename, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue