mirror of https://github.com/jackc/pgx.git
Compare commits
84 Commits
Author | SHA1 | Date |
---|---|---|
|
0f77a2d028 | |
|
ddd966f09f | |
|
924834b5b4 | |
|
9b15554c51 | |
|
037e4cf9a2 | |
|
04bcc0219d | |
|
0e0a7d8344 | |
|
63422c7d6c | |
|
5c1fbf4806 | |
|
05fe5f8b05 | |
|
70c9a147a2 | |
|
6603ddfbe4 | |
|
70f7cad222 | |
|
6bf1b0b1b9 | |
|
14bda65a0c | |
|
9e3c4fb40f | |
|
05e72a5ab1 | |
|
47d631e34b | |
|
58b05f567c | |
|
dcb7193669 | |
|
1abf7d9050 | |
|
b5efc90a32 | |
|
a26c93551f | |
|
2100e1da46 | |
|
2d21a2b80d | |
|
5f33ee5f07 | |
|
228cfffc20 | |
|
a5353af354 | |
|
0bc29e3000 | |
|
9cce05944a | |
|
9c0ad690a9 | |
|
03f08abda3 | |
|
2c1b1c389a | |
|
329cb45913 | |
|
c96a55f8c0 | |
|
e87760682f | |
|
f681632c68 | |
|
3c640a44b6 | |
|
de3f868c1d | |
|
5424d3c873 | |
|
42d3d00734 | |
|
cdc672cf3f | |
|
52e2858629 | |
|
e352784fed | |
|
c2175fe46e | |
|
659823f8f3 | |
|
ca04098fab | |
|
4ff0a454e0 | |
|
00b86ca3db | |
|
61a0227241 | |
|
2190a8e0d1 | |
|
6e9fa42fef | |
|
6d9e6a726e | |
|
02e387ea64 | |
|
e452f80b1d | |
|
da0315d1a4 | |
|
120c89fe0d | |
|
057937db27 | |
|
47cbd8edb8 | |
|
90a77b13b2 | |
|
59d6aa87b9 | |
|
39ffc8b7a4 | |
|
c4c1076d28 | |
|
4293b25262 | |
|
ea1e13a660 | |
|
58d4c0c94f | |
|
1752f7b4c1 | |
|
ee718a110d | |
|
546ad2f4e2 | |
|
efc2c9ff44 | |
|
aabed18db8 | |
|
afa974fb05 | |
|
12b37f3218 | |
|
bcf3fbd780 | |
|
f7c3d190ad | |
|
473a241b96 | |
|
311f72afdc | |
|
877111ceeb | |
|
dc3aea06b5 | |
|
e5d321f920 | |
|
17cd36818c | |
|
76593f37f7 | |
|
29751194ef | |
|
c1f4cbb5cd |
|
@ -14,18 +14,8 @@ jobs:
|
|||
strategy:
|
||||
matrix:
|
||||
go-version: ["1.22", "1.23"]
|
||||
pg-version: [12, 13, 14, 15, 16, cockroachdb]
|
||||
pg-version: [13, 14, 15, 16, 17, cockroachdb]
|
||||
include:
|
||||
- pg-version: 12
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 13
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
|
@ -66,6 +56,16 @@ jobs:
|
|||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: 17
|
||||
pgx-test-database: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-unix-socket-conn-string: "host=/var/run/postgresql dbname=pgx_test"
|
||||
pgx-test-tcp-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-scram-password-conn-string: "host=127.0.0.1 user=pgx_scram password=secret dbname=pgx_test"
|
||||
pgx-test-md5-password-conn-string: "host=127.0.0.1 user=pgx_md5 password=secret dbname=pgx_test"
|
||||
pgx-test-plain-password-conn-string: "host=127.0.0.1 user=pgx_pw password=secret dbname=pgx_test"
|
||||
pgx-test-tls-conn-string: "host=localhost user=pgx_ssl password=secret sslmode=verify-full sslrootcert=/tmp/ca.pem dbname=pgx_test"
|
||||
pgx-ssl-password: certpw
|
||||
pgx-test-tls-client-conn-string: "host=localhost user=pgx_sslcert sslmode=verify-full sslrootcert=/tmp/ca.pem sslcert=/tmp/pgx_sslcert.crt sslkey=/tmp/pgx_sslcert.key dbname=pgx_test"
|
||||
- pg-version: cockroachdb
|
||||
pgx-test-database: "postgresql://root@127.0.0.1:26257/pgx_test?sslmode=disable&experimental_enable_temp_tables=on"
|
||||
|
||||
|
|
18
CHANGELOG.md
18
CHANGELOG.md
|
@ -1,3 +1,20 @@
|
|||
# 5.7.4 (March 24, 2025)
|
||||
|
||||
* Fix / revert change to scanning JSON `null` (Felix Röhrich)
|
||||
|
||||
# 5.7.3 (March 21, 2025)
|
||||
|
||||
* Expose EmptyAcquireWaitTime in pgxpool.Stat (vamshiaruru32)
|
||||
* Improve SQL sanitizer performance (ninedraft)
|
||||
* Fix Scan confusion with json(b), sql.Scanner, and automatic dereferencing (moukoublen, felix-roehrich)
|
||||
* Fix Values() for xml type always returning nil instead of []byte
|
||||
* Add ability to send Flush message in pipeline mode (zenkovev)
|
||||
* Fix pgtype.Timestamp's JSON behavior to match PostgreSQL (pconstantinou)
|
||||
* Better error messages when scanning structs (logicbomb)
|
||||
* Fix handling of error on batch write (bonnefoa)
|
||||
* Match libpq's connection fallback behavior more closely (felix-roehrich)
|
||||
* Add MinIdleConns to pgxpool (djahandarie)
|
||||
|
||||
# 5.7.2 (December 21, 2024)
|
||||
|
||||
* Fix prepared statement already exists on batch prepare failure
|
||||
|
@ -9,6 +26,7 @@
|
|||
* Implement pgtype.UUID.String() (Konstantin Grachev)
|
||||
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
|
||||
* Update golang.org/x/crypto
|
||||
* Fix json(b) columns prefer sql.Scanner interface like database/sql (Ludovico Russo)
|
||||
|
||||
# 5.7.1 (September 10, 2024)
|
||||
|
||||
|
|
14
README.md
14
README.md
|
@ -92,7 +92,7 @@ See the presentation at Golang Estonia, [PGX Top to Bottom](https://www.youtube.
|
|||
|
||||
## Supported Go and PostgreSQL Versions
|
||||
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.21 and higher and PostgreSQL 12 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
pgx supports the same versions of Go and PostgreSQL that are supported by their respective teams. For [Go](https://golang.org/doc/devel/release.html#policy) that is the two most recent major releases and for [PostgreSQL](https://www.postgresql.org/support/versioning/) the major releases in the last 5 years. This means pgx supports Go 1.22 and higher and PostgreSQL 13 and higher. pgx also is tested against the latest version of [CockroachDB](https://www.cockroachlabs.com/product/).
|
||||
|
||||
## Version Policy
|
||||
|
||||
|
@ -172,3 +172,15 @@ Supports, structs, maps, slices and custom mapping functions.
|
|||
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
|
||||
|
||||
Code first migration library for native pgx (no database/sql abstraction).
|
||||
|
||||
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
|
||||
|
||||
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
|
||||
|
||||
### [https://github.com/nikolayk812/pgx-outbox](https://github.com/nikolayk812/pgx-outbox)
|
||||
|
||||
Simple Golang implementation for transactional outbox pattern for PostgreSQL using jackc/pgx driver.
|
||||
|
||||
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
|
||||
|
||||
Simplifies working with the pgx library, providing convenient scanning of nested structures.
|
||||
|
|
2
Rakefile
2
Rakefile
|
@ -2,7 +2,7 @@ require "erb"
|
|||
|
||||
rule '.go' => '.go.erb' do |task|
|
||||
erb = ERB.new(File.read(task.source))
|
||||
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
|
||||
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
|
||||
sh "goimports", "-w", task.name
|
||||
end
|
||||
|
||||
|
|
|
@ -42,8 +42,8 @@ fi
|
|||
|
||||
if [[ "${PGVERSION-}" =~ ^cockroach ]]
|
||||
then
|
||||
wget -qO- https://binaries.cockroachdb.com/cockroach-v23.1.3.linux-amd64.tgz | tar xvz
|
||||
sudo mv cockroach-v23.1.3.linux-amd64/cockroach /usr/local/bin/
|
||||
wget -qO- https://binaries.cockroachdb.com/cockroach-v24.3.3.linux-amd64.tgz | tar xvz
|
||||
sudo mv cockroach-v24.3.3.linux-amd64/cockroach /usr/local/bin/
|
||||
cockroach start-single-node --insecure --background --listen-addr=localhost
|
||||
cockroach sql --insecure -e 'create database pgx_test'
|
||||
fi
|
||||
|
|
21
conn.go
21
conn.go
|
@ -420,7 +420,7 @@ func (c *Conn) IsClosed() bool {
|
|||
return c.pgConn.IsClosed()
|
||||
}
|
||||
|
||||
func (c *Conn) die(err error) {
|
||||
func (c *Conn) die() {
|
||||
if c.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
@ -588,14 +588,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
|
|||
return result.CommandTag, result.Err
|
||||
}
|
||||
|
||||
type unknownArgumentTypeQueryExecModeExecError struct {
|
||||
arg any
|
||||
}
|
||||
|
||||
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
|
||||
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
|
||||
}
|
||||
|
||||
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
|
||||
err := c.eqb.Build(c.typeMap, nil, args)
|
||||
if err != nil {
|
||||
|
@ -661,11 +653,12 @@ const (
|
|||
// should implement pgtype.Int64Valuer.
|
||||
QueryExecModeExec
|
||||
|
||||
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. Queries
|
||||
// are executed in a single round trip. Type mappings can be registered with pgtype.Map.RegisterDefaultPgType. Queries
|
||||
// will be rejected that have arguments that are unregistered or ambiguous. e.g. A map[string]string may have the
|
||||
// PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a map[string]string directly as an
|
||||
// argument. This mode cannot.
|
||||
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments. This is
|
||||
// especially significant for []byte values. []byte values are encoded as PostgreSQL bytea. string must be used
|
||||
// instead for text type values including json and jsonb. Type mappings can be registered with
|
||||
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambiguous.
|
||||
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use a
|
||||
// map[string]string directly as an argument. This mode cannot. Queries are executed in a single round trip.
|
||||
//
|
||||
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec. This includes
|
||||
// the warning regarding differences in text format and binary format encoding with user defined types. There may be
|
||||
|
|
|
@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
|
|||
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
|
||||
|
||||
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
|
||||
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
|
||||
}
|
||||
|
|
|
@ -161,7 +161,7 @@ type derivedTypeInfo struct {
|
|||
// The result of this call can be passed into RegisterTypes to complete the process.
|
||||
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
|
||||
m := c.TypeMap()
|
||||
if typeNames == nil || len(typeNames) == 0 {
|
||||
if len(typeNames) == 0 {
|
||||
return nil, fmt.Errorf("No type names were supplied.")
|
||||
}
|
||||
|
||||
|
@ -169,13 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
|
|||
// the SQL not support recent structures such as multirange
|
||||
serverVersion, _ := serverVersion(c)
|
||||
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
|
||||
var rows Rows
|
||||
var err error
|
||||
if typeNames == nil {
|
||||
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
|
||||
} else {
|
||||
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
||||
}
|
||||
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("While generating load types query: %w", err)
|
||||
}
|
||||
|
@ -232,15 +226,15 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
|
|||
default:
|
||||
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
|
||||
}
|
||||
if type_ != nil {
|
||||
m.RegisterType(type_)
|
||||
if ti.NspName != "" {
|
||||
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
|
||||
m.RegisterType(nspType)
|
||||
result = append(result, nspType)
|
||||
}
|
||||
result = append(result, type_)
|
||||
|
||||
// the type_ is imposible to be null
|
||||
m.RegisterType(type_)
|
||||
if ti.NspName != "" {
|
||||
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
|
||||
m.RegisterType(nspType)
|
||||
result = append(result, nspType)
|
||||
}
|
||||
result = append(result, type_)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
current_branch=$(git rev-parse --abbrev-ref HEAD)
|
||||
if [ "$current_branch" == "HEAD" ]; then
|
||||
current_branch=$(git rev-parse HEAD)
|
||||
fi
|
||||
|
||||
restore_branch() {
|
||||
echo "Restoring original branch/commit: $current_branch"
|
||||
git checkout "$current_branch"
|
||||
}
|
||||
trap restore_branch EXIT
|
||||
|
||||
# Check if there are uncommitted changes
|
||||
if ! git diff --quiet || ! git diff --cached --quiet; then
|
||||
echo "There are uncommitted changes. Please commit or stash them before running this script."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Ensure that at least one commit argument is passed
|
||||
if [ "$#" -lt 1 ]; then
|
||||
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
commits=("$@")
|
||||
benchmarks_dir=benchmarks
|
||||
|
||||
if ! mkdir -p "${benchmarks_dir}"; then
|
||||
echo "Unable to create dir for benchmarks data"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Benchmark results
|
||||
bench_files=()
|
||||
|
||||
# Run benchmark for each listed commit
|
||||
for i in "${!commits[@]}"; do
|
||||
commit="${commits[i]}"
|
||||
git checkout "$commit" || {
|
||||
echo "Failed to checkout $commit"
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Sanitized commmit message
|
||||
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
|
||||
|
||||
# Benchmark data will go there
|
||||
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
|
||||
|
||||
if ! go test -bench=. -count=10 >"$bench_file"; then
|
||||
echo "Benchmarking failed for commit $commit"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
bench_files+=("$bench_file")
|
||||
done
|
||||
|
||||
# go install golang.org/x/perf/cmd/benchstat[@latest]
|
||||
benchstat "${bench_files[@]}"
|
|
@ -4,8 +4,10 @@ import (
|
|||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
@ -24,18 +26,33 @@ type Query struct {
|
|||
// https://github.com/jackc/pgx/issues/1380
|
||||
const replacementcharacterwidth = 3
|
||||
|
||||
const maxBufSize = 16384 // 16 Ki
|
||||
|
||||
var bufPool = &pool[*bytes.Buffer]{
|
||||
new: func() *bytes.Buffer {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
reset: func(b *bytes.Buffer) bool {
|
||||
n := b.Len()
|
||||
b.Reset()
|
||||
return n < maxBufSize
|
||||
},
|
||||
}
|
||||
|
||||
var null = []byte("null")
|
||||
|
||||
func (q *Query) Sanitize(args ...any) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
buf := bufPool.get()
|
||||
defer bufPool.put(buf)
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
buf.WriteString(part)
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
|
||||
var p []byte
|
||||
if argIdx < 0 {
|
||||
return "", fmt.Errorf("first sql argument must be > 0")
|
||||
}
|
||||
|
@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
|||
if argIdx >= len(args) {
|
||||
return "", fmt.Errorf("insufficient arguments")
|
||||
}
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
buf.WriteByte(' ')
|
||||
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
p = null
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
p = QuoteBytes(buf.AvailableBuffer(), arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
p = QuoteString(buf.AvailableBuffer(), arg)
|
||||
case time.Time:
|
||||
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
p = arg.Truncate(time.Microsecond).
|
||||
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
|
||||
buf.Write(p)
|
||||
|
||||
// Prevent SQL injection via Line Comment Creation
|
||||
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
|
||||
str = " " + str + " "
|
||||
buf.WriteByte(' ')
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
|
@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
|
|||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
query := &Query{}
|
||||
query.init(sql)
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
var sqlLexerPool = &pool[*sqlLexer]{
|
||||
new: func() *sqlLexer {
|
||||
return &sqlLexer{}
|
||||
},
|
||||
reset: func(sl *sqlLexer) bool {
|
||||
*sl = sqlLexer{}
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
func (q *Query) init(sql string) {
|
||||
parts := q.Parts[:0]
|
||||
if parts == nil {
|
||||
// dirty, but fast heuristic to preallocate for ~90% usecases
|
||||
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
|
||||
parts = make([]Part, 0, n)
|
||||
}
|
||||
|
||||
l := sqlLexerPool.get()
|
||||
defer sqlLexerPool.put(l)
|
||||
|
||||
l.src = sql
|
||||
l.stateFn = rawState
|
||||
l.parts = parts
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
q.Parts = l.parts
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
func QuoteString(dst []byte, str string) []byte {
|
||||
const quote = '\''
|
||||
|
||||
// Preallocate space for the worst case scenario
|
||||
dst = slices.Grow(dst, len(str)*2+2)
|
||||
|
||||
// Add opening quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
// Iterate through the string without allocating
|
||||
for i := 0; i < len(str); i++ {
|
||||
if str[i] == quote {
|
||||
dst = append(dst, quote, quote)
|
||||
} else {
|
||||
dst = append(dst, str[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Add closing quote
|
||||
dst = append(dst, quote)
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
func QuoteBytes(dst, buf []byte) []byte {
|
||||
if len(buf) == 0 {
|
||||
return append(dst, `'\x'`...)
|
||||
}
|
||||
|
||||
// Calculate required length
|
||||
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
|
||||
|
||||
// Ensure dst has enough capacity
|
||||
if cap(dst)-len(dst) < requiredLen {
|
||||
newDst := make([]byte, len(dst), len(dst)+requiredLen)
|
||||
copy(newDst, dst)
|
||||
dst = newDst
|
||||
}
|
||||
|
||||
// Record original length and extend slice
|
||||
origLen := len(dst)
|
||||
dst = dst[:origLen+requiredLen]
|
||||
|
||||
// Add prefix
|
||||
dst[origLen] = '\''
|
||||
dst[origLen+1] = '\\'
|
||||
dst[origLen+2] = 'x'
|
||||
|
||||
// Encode bytes directly into dst
|
||||
hex.Encode(dst[origLen+3:len(dst)-1], buf)
|
||||
|
||||
// Add suffix
|
||||
dst[len(dst)-1] = '\''
|
||||
|
||||
return dst
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
|
@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
|
|||
}
|
||||
}
|
||||
|
||||
var queryPool = &pool[*Query]{
|
||||
new: func() *Query {
|
||||
return &Query{}
|
||||
},
|
||||
reset: func(q *Query) bool {
|
||||
n := len(q.Parts)
|
||||
q.Parts = q.Parts[:0]
|
||||
return n < 64 // drop too large queries
|
||||
},
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...any) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
query := queryPool.get()
|
||||
query.init(sql)
|
||||
defer queryPool.put(query)
|
||||
|
||||
return query.Sanitize(args...)
|
||||
}
|
||||
|
||||
type pool[E any] struct {
|
||||
p sync.Pool
|
||||
new func() E
|
||||
reset func(E) bool
|
||||
}
|
||||
|
||||
func (pool *pool[E]) get() E {
|
||||
v, ok := pool.p.Get().(E)
|
||||
if !ok {
|
||||
v = pool.new()
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
func (p *pool[E]) put(v E) {
|
||||
if p.reset(v) {
|
||||
p.p.Put(v)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
// sanitize_benchmark_test.go
|
||||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
var benchmarkSanitizeResult string
|
||||
|
||||
const benchmarkQuery = "" +
|
||||
`SELECT *
|
||||
FROM "water_containers"
|
||||
WHERE NOT "id" = $1 -- int64
|
||||
AND "tags" NOT IN $2 -- nil
|
||||
AND "volume" > $3 -- float64
|
||||
AND "transportable" = $4 -- bool
|
||||
AND position($5 IN "sign") -- bytes
|
||||
AND "label" LIKE $6 -- string
|
||||
AND "created_at" > $7; -- time.Time`
|
||||
|
||||
var benchmarkArgs = []any{
|
||||
int64(12345),
|
||||
nil,
|
||||
float64(500),
|
||||
true,
|
||||
[]byte("8BADF00D"),
|
||||
"kombucha's han'dy awokowa",
|
||||
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
|
||||
func BenchmarkSanitize(b *testing.B) {
|
||||
query, err := sanitize.NewQuery(benchmarkQuery)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to create query: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize query: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var benchmarkNewSQLResult string
|
||||
|
||||
func BenchmarkSanitizeSQL(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
var err error
|
||||
for i := 0; i < b.N; i++ {
|
||||
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
|
||||
if err != nil {
|
||||
b.Fatalf("failed to sanitize SQL: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/sanitize"
|
||||
)
|
||||
|
||||
func FuzzQuoteString(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add("new\nline")
|
||||
f.Add("sample text")
|
||||
f.Add("sample q'u'o't'e's")
|
||||
f.Add("select 'quoted $42', $1")
|
||||
|
||||
f.Fuzz(func(t *testing.T, input string) {
|
||||
got := string(sanitize.QuoteString([]byte(prefix), input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzQuoteBytes(f *testing.F) {
|
||||
const prefix = "prefix"
|
||||
f.Add([]byte(nil))
|
||||
f.Add([]byte("\n"))
|
||||
f.Add([]byte("sample text"))
|
||||
f.Add([]byte("sample q'u'o't'e's"))
|
||||
f.Add([]byte("select 'quoted $42', $1"))
|
||||
|
||||
f.Fuzz(func(t *testing.T, input []byte) {
|
||||
got := string(sanitize.QuoteBytes([]byte(prefix), input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
quoted, ok := strings.CutPrefix(got, prefix)
|
||||
if !ok {
|
||||
t.Fatalf("result has no prefix")
|
||||
}
|
||||
|
||||
if want != quoted {
|
||||
t.Errorf("got %q", got)
|
||||
t.Fatalf("want %q", want)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuoteString(t *testing.T) {
|
||||
tc := func(name, input string) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteString(nil, input))
|
||||
want := oldQuoteString(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("empty", "")
|
||||
tc("text", "abcd")
|
||||
tc("with quotes", `one's hat is always a cat`)
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteString(str string) string {
|
||||
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
|
||||
}
|
||||
|
||||
func TestQuoteBytes(t *testing.T) {
|
||||
tc := func(name string, input []byte) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := string(sanitize.QuoteBytes(nil, input))
|
||||
want := oldQuoteBytes(input)
|
||||
|
||||
if got != want {
|
||||
t.Errorf("got: %s", got)
|
||||
t.Fatalf("want: %s", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
tc("nil", nil)
|
||||
tc("empty", []byte{})
|
||||
tc("text", []byte("abcd"))
|
||||
}
|
||||
|
||||
// This function was used before optimizations.
|
||||
// You should keep for testing purposes - we want to ensure there are no breaking changes.
|
||||
func oldQuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
|
|
@ -51,6 +51,8 @@ type Config struct {
|
|||
KerberosSpn string
|
||||
Fallbacks []*FallbackConfig
|
||||
|
||||
SSLNegotiation string // sslnegotiation=postgres or sslnegotiation=direct
|
||||
|
||||
// ValidateConnect is called during a connection attempt after a successful authentication with the PostgreSQL server.
|
||||
// It can be used to validate that the server is acceptable. If this returns an error the connection is closed and the next
|
||||
// fallback config is tried. This allows implementing high availability behavior such as libpq does with target_session_attrs.
|
||||
|
@ -318,6 +320,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||
"sslkey": {},
|
||||
"sslcert": {},
|
||||
"sslrootcert": {},
|
||||
"sslnegotiation": {},
|
||||
"sslpassword": {},
|
||||
"sslsni": {},
|
||||
"krbspn": {},
|
||||
|
@ -386,6 +389,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||
config.Port = fallbacks[0].Port
|
||||
config.TLSConfig = fallbacks[0].TLSConfig
|
||||
config.Fallbacks = fallbacks[1:]
|
||||
config.SSLNegotiation = settings["sslnegotiation"]
|
||||
|
||||
passfile, err := pgpassfile.ReadPassfile(settings["passfile"])
|
||||
if err == nil {
|
||||
|
@ -449,6 +453,7 @@ func parseEnvSettings() map[string]string {
|
|||
"PGSSLSNI": "sslsni",
|
||||
"PGSSLROOTCERT": "sslrootcert",
|
||||
"PGSSLPASSWORD": "sslpassword",
|
||||
"PGSSLNEGOTIATION": "sslnegotiation",
|
||||
"PGTARGETSESSIONATTRS": "target_session_attrs",
|
||||
"PGSERVICE": "service",
|
||||
"PGSERVICEFILE": "servicefile",
|
||||
|
@ -646,6 +651,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
|||
sslkey := settings["sslkey"]
|
||||
sslpassword := settings["sslpassword"]
|
||||
sslsni := settings["sslsni"]
|
||||
sslnegotiation := settings["sslnegotiation"]
|
||||
|
||||
// Match libpq default behavior
|
||||
if sslmode == "" {
|
||||
|
@ -657,6 +663,13 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
|||
|
||||
tlsConfig := &tls.Config{}
|
||||
|
||||
if sslnegotiation == "direct" {
|
||||
tlsConfig.NextProtos = []string{"postgresql"}
|
||||
if sslmode == "prefer" {
|
||||
sslmode = "require"
|
||||
}
|
||||
}
|
||||
|
||||
if sslrootcert != "" {
|
||||
var caCertPool *x509.CertPool
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ func TestContextWatcherContextCancelled(t *testing.T) {
|
|||
require.True(t, cleanupCalled, "Cleanup func was not called")
|
||||
}
|
||||
|
||||
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
|
||||
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
|
||||
cw := ctxwatch.NewContextWatcher(&testHandler{
|
||||
handleCancel: func(context.Context) {
|
||||
t.Error("cancel func should not have been called")
|
||||
|
|
250
pgconn/pgconn.go
250
pgconn/pgconn.go
|
@ -1,6 +1,7 @@
|
|||
package pgconn
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
|
@ -267,12 +268,15 @@ func connectPreferred(ctx context.Context, config *Config, connectOneConfigs []*
|
|||
|
||||
var pgErr *PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
const ERRCODE_INVALID_PASSWORD = "28P01" // wrong password
|
||||
const ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION = "28000" // wrong password or bad pg_hba.conf settings
|
||||
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
|
||||
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
|
||||
// pgx will try next host even if libpq does not in certain cases (see #2246)
|
||||
// consider change for the next major version
|
||||
|
||||
const ERRCODE_INVALID_PASSWORD = "28P01"
|
||||
const ERRCODE_INVALID_CATALOG_NAME = "3D000" // db does not exist
|
||||
const ERRCODE_INSUFFICIENT_PRIVILEGE = "42501" // missing connect privilege
|
||||
|
||||
// auth failed due to invalid password, db does not exist or user has no permission
|
||||
if pgErr.Code == ERRCODE_INVALID_PASSWORD ||
|
||||
pgErr.Code == ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION && c.tlsConfig != nil ||
|
||||
pgErr.Code == ERRCODE_INVALID_CATALOG_NAME ||
|
||||
pgErr.Code == ERRCODE_INSUFFICIENT_PRIVILEGE {
|
||||
return nil, allErrors
|
||||
|
@ -321,7 +325,15 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo
|
|||
if connectConfig.tlsConfig != nil {
|
||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(&DeadlineContextWatcherHandler{Conn: pgConn.conn})
|
||||
pgConn.contextWatcher.Watch(ctx)
|
||||
tlsConn, err := startTLS(pgConn.conn, connectConfig.tlsConfig)
|
||||
var (
|
||||
tlsConn net.Conn
|
||||
err error
|
||||
)
|
||||
if config.SSLNegotiation == "direct" {
|
||||
tlsConn = tls.Client(pgConn.conn, connectConfig.tlsConfig)
|
||||
} else {
|
||||
tlsConn, err = startTLS(pgConn.conn, connectConfig.tlsConfig)
|
||||
}
|
||||
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||
if err != nil {
|
||||
pgConn.conn.Close()
|
||||
|
@ -1408,9 +1420,8 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||
|
||||
// MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch.
|
||||
type MultiResultReader struct {
|
||||
pgConn *PgConn
|
||||
ctx context.Context
|
||||
pipeline *Pipeline
|
||||
pgConn *PgConn
|
||||
ctx context.Context
|
||||
|
||||
rr *ResultReader
|
||||
|
||||
|
@ -1443,12 +1454,8 @@ func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error)
|
|||
switch msg := msg.(type) {
|
||||
case *pgproto3.ReadyForQuery:
|
||||
mrr.closed = true
|
||||
if mrr.pipeline != nil {
|
||||
mrr.pipeline.expectedReadyForQueryCount--
|
||||
} else {
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.pgConn.unlock()
|
||||
}
|
||||
mrr.pgConn.contextWatcher.Unwatch()
|
||||
mrr.pgConn.unlock()
|
||||
case *pgproto3.ErrorResponse:
|
||||
mrr.err = ErrorResponseToPgError(msg)
|
||||
}
|
||||
|
@ -1672,7 +1679,11 @@ func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error
|
|||
case *pgproto3.EmptyQueryResponse:
|
||||
rr.concludeCommand(CommandTag{}, nil)
|
||||
case *pgproto3.ErrorResponse:
|
||||
rr.concludeCommand(CommandTag{}, ErrorResponseToPgError(msg))
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
if rr.pipeline != nil {
|
||||
rr.pipeline.state.HandleError(pgErr)
|
||||
}
|
||||
rr.concludeCommand(CommandTag{}, pgErr)
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
|
@ -1773,9 +1784,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
|
||||
batch.buf, batch.err = (&pgproto3.Sync{}).Encode(batch.buf)
|
||||
if batch.err != nil {
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
multiResult.err = normalizeTimeoutError(multiResult.ctx, batch.err)
|
||||
multiResult.closed = true
|
||||
multiResult.err = batch.err
|
||||
pgConn.unlock()
|
||||
pgConn.asyncClose()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
|
@ -1783,9 +1795,10 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
|
|||
defer pgConn.exitPotentialWriteReadDeadlock()
|
||||
_, err := pgConn.conn.Write(batch.buf)
|
||||
if err != nil {
|
||||
pgConn.contextWatcher.Unwatch()
|
||||
multiResult.err = normalizeTimeoutError(multiResult.ctx, err)
|
||||
multiResult.closed = true
|
||||
multiResult.err = err
|
||||
pgConn.unlock()
|
||||
pgConn.asyncClose()
|
||||
return multiResult
|
||||
}
|
||||
|
||||
|
@ -1999,9 +2012,7 @@ type Pipeline struct {
|
|||
conn *PgConn
|
||||
ctx context.Context
|
||||
|
||||
expectedReadyForQueryCount int
|
||||
pendingSync bool
|
||||
|
||||
state pipelineState
|
||||
err error
|
||||
closed bool
|
||||
}
|
||||
|
@ -2012,6 +2023,122 @@ type PipelineSync struct{}
|
|||
// CloseComplete is returned by GetResults when a CloseComplete message is received.
|
||||
type CloseComplete struct{}
|
||||
|
||||
type pipelineRequestType int
|
||||
|
||||
const (
|
||||
pipelineNil pipelineRequestType = iota
|
||||
pipelinePrepare
|
||||
pipelineQueryParams
|
||||
pipelineQueryPrepared
|
||||
pipelineDeallocate
|
||||
pipelineSyncRequest
|
||||
pipelineFlushRequest
|
||||
)
|
||||
|
||||
type pipelineRequestEvent struct {
|
||||
RequestType pipelineRequestType
|
||||
WasSentToServer bool
|
||||
BeforeFlushOrSync bool
|
||||
}
|
||||
|
||||
type pipelineState struct {
|
||||
requestEventQueue list.List
|
||||
lastRequestType pipelineRequestType
|
||||
pgErr *PgError
|
||||
expectedReadyForQueryCount int
|
||||
}
|
||||
|
||||
func (s *pipelineState) Init() {
|
||||
s.requestEventQueue.Init()
|
||||
s.lastRequestType = pipelineNil
|
||||
}
|
||||
|
||||
func (s *pipelineState) RegisterSendingToServer() {
|
||||
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
if val.WasSentToServer {
|
||||
return
|
||||
}
|
||||
val.WasSentToServer = true
|
||||
elem.Value = val
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) registerFlushingBufferOnServer() {
|
||||
for elem := s.requestEventQueue.Back(); elem != nil; elem = elem.Prev() {
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
if val.BeforeFlushOrSync {
|
||||
return
|
||||
}
|
||||
val.BeforeFlushOrSync = true
|
||||
elem.Value = val
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) PushBackRequestType(req pipelineRequestType) {
|
||||
if req == pipelineNil {
|
||||
return
|
||||
}
|
||||
|
||||
if req != pipelineFlushRequest {
|
||||
s.requestEventQueue.PushBack(pipelineRequestEvent{RequestType: req})
|
||||
}
|
||||
if req == pipelineFlushRequest || req == pipelineSyncRequest {
|
||||
s.registerFlushingBufferOnServer()
|
||||
}
|
||||
s.lastRequestType = req
|
||||
|
||||
if req == pipelineSyncRequest {
|
||||
s.expectedReadyForQueryCount++
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) ExtractFrontRequestType() pipelineRequestType {
|
||||
for {
|
||||
elem := s.requestEventQueue.Front()
|
||||
if elem == nil {
|
||||
return pipelineNil
|
||||
}
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
if !(val.WasSentToServer && val.BeforeFlushOrSync) {
|
||||
return pipelineNil
|
||||
}
|
||||
|
||||
s.requestEventQueue.Remove(elem)
|
||||
if val.RequestType == pipelineSyncRequest {
|
||||
s.pgErr = nil
|
||||
}
|
||||
if s.pgErr == nil {
|
||||
return val.RequestType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *pipelineState) HandleError(err *PgError) {
|
||||
s.pgErr = err
|
||||
}
|
||||
|
||||
func (s *pipelineState) HandleReadyForQuery() {
|
||||
s.expectedReadyForQueryCount--
|
||||
}
|
||||
|
||||
func (s *pipelineState) PendingSync() bool {
|
||||
var notPendingSync bool
|
||||
|
||||
if elem := s.requestEventQueue.Back(); elem != nil {
|
||||
val := elem.Value.(pipelineRequestEvent)
|
||||
notPendingSync = (val.RequestType == pipelineSyncRequest) && val.WasSentToServer
|
||||
} else {
|
||||
notPendingSync = (s.lastRequestType == pipelineSyncRequest) || (s.lastRequestType == pipelineNil)
|
||||
}
|
||||
|
||||
return !notPendingSync
|
||||
}
|
||||
|
||||
func (s *pipelineState) ExpectedReadyForQuery() int {
|
||||
return s.expectedReadyForQueryCount
|
||||
}
|
||||
|
||||
// StartPipeline switches the connection to pipeline mode and returns a *Pipeline. In pipeline mode requests can be sent
|
||||
// to the server without waiting for a response. Close must be called on the returned *Pipeline to return the connection
|
||||
// to normal mode. While in pipeline mode, no methods that communicate with the server may be called except
|
||||
|
@ -2020,16 +2147,21 @@ type CloseComplete struct{}
|
|||
// Prefer ExecBatch when only sending one group of queries at once.
|
||||
func (pgConn *PgConn) StartPipeline(ctx context.Context) *Pipeline {
|
||||
if err := pgConn.lock(); err != nil {
|
||||
return &Pipeline{
|
||||
pipeline := &Pipeline{
|
||||
closed: true,
|
||||
err: err,
|
||||
}
|
||||
pipeline.state.Init()
|
||||
|
||||
return pipeline
|
||||
}
|
||||
|
||||
pgConn.pipeline = Pipeline{
|
||||
conn: pgConn,
|
||||
ctx: ctx,
|
||||
}
|
||||
pgConn.pipeline.state.Init()
|
||||
|
||||
pipeline := &pgConn.pipeline
|
||||
|
||||
if ctx != context.Background() {
|
||||
|
@ -2052,10 +2184,10 @@ func (p *Pipeline) SendPrepare(name, sql string, paramOIDs []uint32) {
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendParse(&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'S', Name: name})
|
||||
p.state.PushBackRequestType(pipelinePrepare)
|
||||
}
|
||||
|
||||
// SendDeallocate deallocates a prepared statement.
|
||||
|
@ -2063,9 +2195,9 @@ func (p *Pipeline) SendDeallocate(name string) {
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendClose(&pgproto3.Close{ObjectType: 'S', Name: name})
|
||||
p.state.PushBackRequestType(pipelineDeallocate)
|
||||
}
|
||||
|
||||
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
||||
|
@ -2073,12 +2205,12 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendParse(&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs})
|
||||
p.conn.frontend.SendBind(&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
p.state.PushBackRequestType(pipelineQueryParams)
|
||||
}
|
||||
|
||||
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
||||
|
@ -2086,11 +2218,42 @@ func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, para
|
|||
if p.closed {
|
||||
return
|
||||
}
|
||||
p.pendingSync = true
|
||||
|
||||
p.conn.frontend.SendBind(&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats})
|
||||
p.conn.frontend.SendDescribe(&pgproto3.Describe{ObjectType: 'P'})
|
||||
p.conn.frontend.SendExecute(&pgproto3.Execute{})
|
||||
p.state.PushBackRequestType(pipelineQueryPrepared)
|
||||
}
|
||||
|
||||
// SendFlushRequest sends a request for the server to flush its output buffer.
|
||||
//
|
||||
// The server flushes its output buffer automatically as a result of Sync being called,
|
||||
// or on any request when not in pipeline mode; this function is useful to cause the server
|
||||
// to flush its output buffer in pipeline mode without establishing a synchronization point.
|
||||
// Note that the request is not itself flushed to the server automatically; use Flush if
|
||||
// necessary. This copies the behavior of libpq PQsendFlushRequest.
|
||||
func (p *Pipeline) SendFlushRequest() {
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
|
||||
p.conn.frontend.Send(&pgproto3.Flush{})
|
||||
p.state.PushBackRequestType(pipelineFlushRequest)
|
||||
}
|
||||
|
||||
// SendPipelineSync marks a synchronization point in a pipeline by sending a sync message
|
||||
// without flushing the send buffer. This serves as the delimiter of an implicit
|
||||
// transaction and an error recovery point.
|
||||
//
|
||||
// Note that the request is not itself flushed to the server automatically; use Flush if
|
||||
// necessary. This copies the behavior of libpq PQsendPipelineSync.
|
||||
func (p *Pipeline) SendPipelineSync() {
|
||||
if p.closed {
|
||||
return
|
||||
}
|
||||
|
||||
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
||||
p.state.PushBackRequestType(pipelineSyncRequest)
|
||||
}
|
||||
|
||||
// Flush flushes the queued requests without establishing a synchronization point.
|
||||
|
@ -2115,28 +2278,14 @@ func (p *Pipeline) Flush() error {
|
|||
return err
|
||||
}
|
||||
|
||||
p.state.RegisterSendingToServer()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync establishes a synchronization point and flushes the queued requests.
|
||||
func (p *Pipeline) Sync() error {
|
||||
if p.closed {
|
||||
if p.err != nil {
|
||||
return p.err
|
||||
}
|
||||
return errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
p.conn.frontend.SendSync(&pgproto3.Sync{})
|
||||
err := p.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.pendingSync = false
|
||||
p.expectedReadyForQueryCount++
|
||||
|
||||
return nil
|
||||
p.SendPipelineSync()
|
||||
return p.Flush()
|
||||
}
|
||||
|
||||
// GetResults gets the next results. If results are present, results may be a *ResultReader, *StatementDescription, or
|
||||
|
@ -2150,7 +2299,7 @@ func (p *Pipeline) GetResults() (results any, err error) {
|
|||
return nil, errors.New("pipeline closed")
|
||||
}
|
||||
|
||||
if p.expectedReadyForQueryCount == 0 {
|
||||
if p.state.ExtractFrontRequestType() == pipelineNil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
@ -2195,13 +2344,13 @@ func (p *Pipeline) getResults() (results any, err error) {
|
|||
case *pgproto3.CloseComplete:
|
||||
return &CloseComplete{}, nil
|
||||
case *pgproto3.ReadyForQuery:
|
||||
p.expectedReadyForQueryCount--
|
||||
p.state.HandleReadyForQuery()
|
||||
return &PipelineSync{}, nil
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
p.state.HandleError(pgErr)
|
||||
return nil, pgErr
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2231,6 +2380,7 @@ func (p *Pipeline) getResultsPrepare() (*StatementDescription, error) {
|
|||
// These should never happen here. But don't take chances that could lead to a deadlock.
|
||||
case *pgproto3.ErrorResponse:
|
||||
pgErr := ErrorResponseToPgError(msg)
|
||||
p.state.HandleError(pgErr)
|
||||
return nil, pgErr
|
||||
case *pgproto3.CommandComplete:
|
||||
p.conn.asyncClose()
|
||||
|
@ -2250,7 +2400,7 @@ func (p *Pipeline) Close() error {
|
|||
|
||||
p.closed = true
|
||||
|
||||
if p.pendingSync {
|
||||
if p.state.PendingSync() {
|
||||
p.conn.asyncClose()
|
||||
p.err = errors.New("pipeline has unsynced requests")
|
||||
p.conn.contextWatcher.Unwatch()
|
||||
|
@ -2259,7 +2409,7 @@ func (p *Pipeline) Close() error {
|
|||
return p.err
|
||||
}
|
||||
|
||||
for p.expectedReadyForQueryCount > 0 {
|
||||
for p.state.ExpectedReadyForQuery() > 0 {
|
||||
_, err := p.getResults()
|
||||
if err != nil {
|
||||
p.err = err
|
||||
|
|
|
@ -14,6 +14,7 @@ import (
|
|||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -1420,6 +1421,52 @@ func TestConnExecBatch(t *testing.T) {
|
|||
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
||||
}
|
||||
|
||||
type mockConnection struct {
|
||||
net.Conn
|
||||
writeLatency *time.Duration
|
||||
}
|
||||
|
||||
func (m mockConnection) Write(b []byte) (n int, err error) {
|
||||
time.Sleep(*m.writeLatency)
|
||||
return m.Conn.Write(b)
|
||||
}
|
||||
|
||||
func TestConnExecBatchWriteError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
var mockConn mockConnection
|
||||
writeLatency := 0 * time.Second
|
||||
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
mockConn = mockConnection{conn, &writeLatency}
|
||||
return mockConn, err
|
||||
}
|
||||
|
||||
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
batch := &pgconn.Batch{}
|
||||
pgConn.Conn()
|
||||
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
||||
writeLatency = 2 * time.Second
|
||||
mrr := pgConn.ExecBatch(ctx2, batch)
|
||||
err = mrr.Close()
|
||||
require.Error(t, err)
|
||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
require.True(t, pgConn.IsClosed())
|
||||
}
|
||||
|
||||
func TestConnExecBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -3105,6 +3152,344 @@ func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
|||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineFlushForSingleRequests(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
|
||||
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, "msg", string(sd.Fields[0].Name))
|
||||
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendDeallocate("ps")
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.CloseComplete)
|
||||
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineFlushForRequestSeries(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
pipeline.SendPrepare("ps", "select $1::bigint as num", nil)
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
sd, ok := results.(*pgconn.StatementDescription)
|
||||
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
||||
require.Len(t, sd.Fields, 1)
|
||||
require.Equal(t, "num", string(sd.Fields[0].Name))
|
||||
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil)
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil)
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineFlushWithError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
pipeline := pgConn.StartPipeline(ctx)
|
||||
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
||||
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err := pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok := results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult := rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
var pgErr *pgconn.PgError
|
||||
require.ErrorAs(t, readResult.Err, &pgErr)
|
||||
require.Equal(t, "22012", pgErr.Code)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
||||
pipeline.SendPipelineSync()
|
||||
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
||||
pipeline.SendFlushRequest()
|
||||
err = pipeline.Flush()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
rr, ok = results.(*pgconn.ResultReader)
|
||||
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
||||
readResult = rr.Read()
|
||||
require.NoError(t, readResult.Err)
|
||||
require.Len(t, readResult.Rows, 1)
|
||||
require.Len(t, readResult.Rows[0], 1)
|
||||
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, results)
|
||||
|
||||
err = pipeline.Sync()
|
||||
require.NoError(t, err)
|
||||
|
||||
results, err = pipeline.GetResults()
|
||||
require.NoError(t, err)
|
||||
_, ok = results.(*pgconn.PipelineSync)
|
||||
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
||||
|
||||
err = pipeline.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -3435,6 +3820,173 @@ func TestSNISupport(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectWithDirectSSLNegotiation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
connString string
|
||||
expectDirectNego bool
|
||||
}{
|
||||
{
|
||||
name: "Default negotiation (postgres)",
|
||||
connString: "sslmode=require",
|
||||
expectDirectNego: false,
|
||||
},
|
||||
{
|
||||
name: "Direct negotiation",
|
||||
connString: "sslmode=require sslnegotiation=direct",
|
||||
expectDirectNego: true,
|
||||
},
|
||||
{
|
||||
name: "Explicit postgres negotiation",
|
||||
connString: "sslmode=require sslnegotiation=postgres",
|
||||
expectDirectNego: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
script := &pgmock.Script{
|
||||
Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
var directNegoObserved atomic.Bool
|
||||
|
||||
serverErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(serverErrCh)
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("accept error: %w", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
|
||||
firstByte := make([]byte, 1)
|
||||
_, err = conn.Read(firstByte)
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("read first byte error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TLS Client Hello (direct) or PostgreSQL SSLRequest
|
||||
isDirect := firstByte[0] >= 20 && firstByte[0] <= 23
|
||||
directNegoObserved.Store(isDirect)
|
||||
|
||||
var tlsConn *tls.Conn
|
||||
|
||||
if !isDirect {
|
||||
// Handle standard PostgreSQL SSL negotiation
|
||||
// Read the rest of the SSL request message
|
||||
sslRequestRemainder := make([]byte, 7)
|
||||
_, err = io.ReadFull(conn, sslRequestRemainder)
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("read ssl request remainder error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Send SSL acceptance response
|
||||
_, err = conn.Write([]byte("S"))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("write ssl acceptance error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Setup TLS server without needing to reuse the first byte
|
||||
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
tlsConn = tls.Server(conn, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
} else {
|
||||
// Handle direct TLS negotiation
|
||||
// Setup TLS server with the first byte already read
|
||||
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("cert error: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Use a wrapper to inject the first byte back into the TLS handshake
|
||||
bufConn := &prefixConn{
|
||||
Conn: conn,
|
||||
prefixData: firstByte,
|
||||
}
|
||||
|
||||
tlsConn = tls.Server(bufConn, &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
})
|
||||
}
|
||||
|
||||
// Complete TLS handshake
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
serverErrCh <- fmt.Errorf("TLS handshake error: %w", err)
|
||||
return
|
||||
}
|
||||
defer tlsConn.Close()
|
||||
|
||||
err = script.Run(pgproto3.NewBackend(tlsConn, tlsConn))
|
||||
if err != nil {
|
||||
serverErrCh <- fmt.Errorf("pgmock run error: %w", err)
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
connStr := fmt.Sprintf("%s host=localhost port=%s sslmode=require sslinsecure=1",
|
||||
tt.connString, port)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := pgconn.Connect(ctx, connStr)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
defer conn.Close(ctx)
|
||||
|
||||
err = <-serverErrCh
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, tt.expectDirectNego, directNegoObserved.Load())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// prefixConn implements a net.Conn that prepends some data to the first Read
|
||||
type prefixConn struct {
|
||||
net.Conn
|
||||
prefixData []byte
|
||||
prefixConsumed bool
|
||||
}
|
||||
|
||||
func (c *prefixConn) Read(b []byte) (n int, err error) {
|
||||
if !c.prefixConsumed && len(c.prefixData) > 0 {
|
||||
n = copy(b, c.prefixData)
|
||||
c.prefixData = c.prefixData[n:]
|
||||
c.prefixConsumed = len(c.prefixData) == 0
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1920
|
||||
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
|
|
@ -12,7 +12,7 @@ type PasswordMessage struct {
|
|||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||
func (*PasswordMessage) Frontend() {}
|
||||
|
||||
// Frontend identifies this message as an authentication response.
|
||||
// InitialResponse identifies this message as an authentication response.
|
||||
func (*PasswordMessage) InitialResponse() {}
|
||||
|
||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||
|
|
|
@ -11,8 +11,6 @@ import (
|
|||
)
|
||||
|
||||
func TestCompositeCodecTranscode(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
_, err := conn.Exec(ctx, `drop type if exists ct_test;
|
||||
|
@ -91,8 +89,6 @@ func (p *point3d) ScanIndex(i int) any {
|
|||
}
|
||||
|
||||
func TestCompositeCodecTranscodeStruct(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||
|
@ -128,8 +124,6 @@ create type point3d as (
|
|||
}
|
||||
|
||||
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||
|
@ -169,8 +163,6 @@ create type point3d as (
|
|||
}
|
||||
|
||||
func TestCompositeCodecDecodeValue(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||
|
@ -214,7 +206,7 @@ create type point3d as (
|
|||
//
|
||||
// https://github.com/jackc/pgx/issues/1576
|
||||
func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
|
||||
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
|
||||
skipCockroachDB(t, "Server does not support composite types from table definitions")
|
||||
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
||||
|
|
|
@ -53,8 +53,8 @@ similar fashion to database/sql. The second is to use a pointer to a pointer.
|
|||
return err
|
||||
}
|
||||
|
||||
When using nullable pgtype types as parameters for queries, one has to remember
|
||||
to explicitly set their Valid field to true, otherwise the parameter's value will be NULL.
|
||||
When using nullable pgtype types as parameters for queries, one has to remember to explicitly set their Valid field to
|
||||
true, otherwise the parameter's value will be NULL.
|
||||
|
||||
JSON Support
|
||||
|
||||
|
@ -159,11 +159,16 @@ example_child_records_test.go for an example.
|
|||
|
||||
Overview of Scanning Implementation
|
||||
|
||||
The first step is to use the OID to lookup the correct Codec. If the OID is unavailable, Map will try to find the OID
|
||||
from previous calls of Map.RegisterDefaultPgType. The Map will call the Codec's PlanScan method to get a plan for
|
||||
scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types are
|
||||
interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner and
|
||||
PointValuer interfaces.
|
||||
The first step is to use the OID to lookup the correct Codec. The Map will call the Codec's PlanScan method to get a
|
||||
plan for scanning into the Go value. A Codec will support scanning into one or more Go types. Oftentime these Go types
|
||||
are interfaces rather than explicit types. For example, PointCodec can use any Go type that implements the PointScanner
|
||||
and PointValuer interfaces.
|
||||
|
||||
If a Go value is not supported directly by a Codec then Map will try see if it is a sql.Scanner. If is then that
|
||||
interface will be used to scan the value. Most sql.Scanners require the input to be in the text format (e.g. UUIDs and
|
||||
numeric). However, pgx will typically have received the value in the binary format. In this case the binary value will be
|
||||
parsed, reencoded as text, and then passed to the sql.Scanner. This may incur additional overhead for query results with
|
||||
a large number of affected values.
|
||||
|
||||
If a Go value is not supported directly by a Codec then Map will try wrapping it with additional logic and try again.
|
||||
For example, Int8Codec does not support scanning into a renamed type (e.g. type myInt64 int64). But Map will detect that
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Do not edit. Generated from pgtype/int.go.erb
|
||||
// Code generated from pgtype/int.go.erb. DO NOT EDIT.
|
||||
|
||||
package pgtype
|
||||
|
||||
import (
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Do not edit. Generated from pgtype/int_test.go.erb
|
||||
// Code generated from pgtype/int_test.go.erb. DO NOT EDIT.
|
||||
|
||||
package pgtype_test
|
||||
|
||||
import (
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT.
|
||||
|
||||
package pgtype_test
|
||||
|
||||
import (
|
||||
|
|
104
pgtype/json.go
104
pgtype/json.go
|
@ -71,6 +71,27 @@ func (c *JSONCodec) PlanEncode(m *Map, oid uint32, format int16, value any) Enco
|
|||
}
|
||||
}
|
||||
|
||||
// JSON needs its on scan plan for pointers to handle 'null'::json(b).
|
||||
// Consider making pointerPointerScanPlan more flexible in the future.
|
||||
type jsonPointerScanPlan struct {
|
||||
next ScanPlan
|
||||
}
|
||||
|
||||
func (p jsonPointerScanPlan) Scan(src []byte, dst any) error {
|
||||
el := reflect.ValueOf(dst).Elem()
|
||||
if src == nil || string(src) == "null" {
|
||||
el.SetZero()
|
||||
return nil
|
||||
}
|
||||
|
||||
el.Set(reflect.New(el.Type().Elem()))
|
||||
if p.next != nil {
|
||||
return p.next.Scan(src, el.Interface())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type encodePlanJSONCodecEitherFormatString struct{}
|
||||
|
||||
func (encodePlanJSONCodecEitherFormatString) Encode(value any, buf []byte) (newBuf []byte, err error) {
|
||||
|
@ -117,58 +138,36 @@ func (e *encodePlanJSONCodecEitherFormatMarshal) Encode(value any, buf []byte) (
|
|||
return buf, nil
|
||||
}
|
||||
|
||||
func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||
switch target.(type) {
|
||||
case *string:
|
||||
return scanPlanAnyToString{}
|
||||
|
||||
case **string:
|
||||
// This is to fix **string scanning. It seems wrong to special case **string, but it's not clear what a better
|
||||
// solution would be.
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/1470 -- **string
|
||||
// https://github.com/jackc/pgx/issues/1691 -- ** anything else
|
||||
|
||||
if wrapperPlan, nextDst, ok := TryPointerPointerScanPlan(target); ok {
|
||||
if nextPlan := m.planScan(oid, format, nextDst, 0); nextPlan != nil {
|
||||
if _, failed := nextPlan.(*scanPlanFail); !failed {
|
||||
wrapperPlan.SetNext(nextPlan)
|
||||
return wrapperPlan
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case *[]byte:
|
||||
return scanPlanJSONToByteSlice{}
|
||||
case BytesScanner:
|
||||
return scanPlanBinaryBytesToBytesScanner{}
|
||||
|
||||
}
|
||||
|
||||
// Cannot rely on sql.Scanner being handled later because scanPlanJSONToJSONUnmarshal will take precedence.
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/1418
|
||||
if isSQLScanner(target) {
|
||||
return &scanPlanSQLScanner{formatCode: format}
|
||||
}
|
||||
|
||||
return &scanPlanJSONToJSONUnmarshal{
|
||||
unmarshal: c.Unmarshal,
|
||||
}
|
||||
func (c *JSONCodec) PlanScan(m *Map, oid uint32, formatCode int16, target any) ScanPlan {
|
||||
return c.planScan(m, oid, formatCode, target, 0)
|
||||
}
|
||||
|
||||
// we need to check if the target is a pointer to a sql.Scanner (or any of the pointer ref tree implements a sql.Scanner).
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/2146
|
||||
func isSQLScanner(v any) bool {
|
||||
val := reflect.ValueOf(v)
|
||||
for val.Kind() == reflect.Ptr {
|
||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||
return true
|
||||
}
|
||||
val = val.Elem()
|
||||
// JSON cannot fallback to pointerPointerScanPlan because of 'null'::json(b),
|
||||
// so we need to duplicate the logic here.
|
||||
func (c *JSONCodec) planScan(m *Map, oid uint32, formatCode int16, target any, depth int) ScanPlan {
|
||||
if depth > 8 {
|
||||
return &scanPlanFail{m: m, oid: oid, formatCode: formatCode}
|
||||
}
|
||||
|
||||
switch target.(type) {
|
||||
case *string:
|
||||
return &scanPlanAnyToString{}
|
||||
case *[]byte:
|
||||
return &scanPlanJSONToByteSlice{}
|
||||
case BytesScanner:
|
||||
return &scanPlanBinaryBytesToBytesScanner{}
|
||||
case sql.Scanner:
|
||||
return &scanPlanSQLScanner{formatCode: formatCode}
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(target)
|
||||
if rv.Kind() == reflect.Pointer && rv.Elem().Kind() == reflect.Pointer {
|
||||
var plan jsonPointerScanPlan
|
||||
plan.next = c.planScan(m, oid, formatCode, rv.Elem().Interface(), depth+1)
|
||||
return plan
|
||||
} else {
|
||||
return &scanPlanJSONToJSONUnmarshal{unmarshal: c.Unmarshal}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type scanPlanAnyToString struct{}
|
||||
|
@ -212,7 +211,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
|
|||
return fmt.Errorf("cannot scan NULL into %T", dst)
|
||||
}
|
||||
|
||||
elem := reflect.ValueOf(dst).Elem()
|
||||
v := reflect.ValueOf(dst)
|
||||
if v.Kind() != reflect.Pointer || v.IsNil() {
|
||||
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
|
||||
}
|
||||
|
||||
elem := v.Elem()
|
||||
elem.Set(reflect.Zero(elem.Type()))
|
||||
|
||||
return s.unmarshal(src, dst)
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -48,6 +49,7 @@ func TestJSONCodec(t *testing.T) {
|
|||
Age int `json:"age"`
|
||||
}
|
||||
|
||||
var str string
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
|
||||
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
|
||||
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
|
||||
|
@ -65,6 +67,9 @@ func TestJSONCodec(t *testing.T) {
|
|||
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
|
||||
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
|
||||
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
|
||||
|
||||
// Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
|
||||
{NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
|
||||
})
|
||||
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||
|
@ -136,6 +141,27 @@ func (i Issue2146) Value() (driver.Value, error) {
|
|||
return string(b), err
|
||||
}
|
||||
|
||||
type NonPointerJSONScanner struct {
|
||||
V *string
|
||||
}
|
||||
|
||||
func (i NonPointerJSONScanner) Scan(src any) error {
|
||||
switch c := src.(type) {
|
||||
case string:
|
||||
*i.V = c
|
||||
case []byte:
|
||||
*i.V = string(c)
|
||||
default:
|
||||
return errors.New("unknown source type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i NonPointerJSONScanner) Value() (driver.Value, error) {
|
||||
return i.V, nil
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
|
||||
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
|
@ -166,11 +192,15 @@ func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
|
|||
// A string cannot scan a NULL.
|
||||
str := "foobar"
|
||||
err = conn.QueryRow(ctx, "select null::json").Scan(&str)
|
||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
||||
fieldName := "json"
|
||||
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||
fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb.
|
||||
}
|
||||
require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *string", fieldName))
|
||||
|
||||
// A non-string cannot scan a NULL.
|
||||
err = conn.QueryRow(ctx, "select null::json").Scan(&n)
|
||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int")
|
||||
require.EqualError(t, err, fmt.Sprintf("can't scan into dest[0] (col: %s): cannot scan NULL into *int", fieldName))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -267,7 +297,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
|||
Unmarshal: func(data []byte, v any) error {
|
||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||
},
|
||||
}})
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
|
||||
|
@ -278,3 +309,54 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
|
|||
}},
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
n := 44
|
||||
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
|
||||
require.Error(t, err)
|
||||
|
||||
var i *int
|
||||
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
|
||||
require.Error(t, err)
|
||||
|
||||
m := 0
|
||||
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 42, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONCodecScanNull(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
var dest struct{}
|
||||
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cannot scan NULL into *struct {}")
|
||||
|
||||
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&dest)
|
||||
require.NoError(t, err)
|
||||
|
||||
var destPointer *struct{}
|
||||
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&destPointer)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, destPointer)
|
||||
|
||||
err = conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&destPointer)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, destPointer)
|
||||
|
||||
var raw json.RawMessage
|
||||
require.NoError(t, conn.QueryRow(ctx, "select 'null'::jsonb").Scan(&raw))
|
||||
require.Equal(t, json.RawMessage("null"), raw)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJSONCodecScanNullToPointerToSQLScanner(t *testing.T) {
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
var dest *Issue2146
|
||||
err := conn.QueryRow(ctx, "select null::jsonb").Scan(&dest)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, dest)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -66,11 +66,11 @@ func TestJSONBCodecUnmarshalSQLNull(t *testing.T) {
|
|||
// A string cannot scan a NULL.
|
||||
str := "foobar"
|
||||
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&str)
|
||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
||||
require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *string")
|
||||
|
||||
// A non-string cannot scan a NULL.
|
||||
err = conn.QueryRow(ctx, "select null::jsonb").Scan(&n)
|
||||
require.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *int")
|
||||
require.EqualError(t, err, "can't scan into dest[0] (col: jsonb): cannot scan NULL into *int")
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -396,11 +396,7 @@ type scanPlanSQLScanner struct {
|
|||
}
|
||||
|
||||
func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
||||
scanner := getSQLScanner(dst)
|
||||
|
||||
if scanner == nil {
|
||||
return fmt.Errorf("cannot scan into %T", dst)
|
||||
}
|
||||
scanner := dst.(sql.Scanner)
|
||||
|
||||
if src == nil {
|
||||
// This is necessary because interface value []byte:nil does not equal nil:nil for the binary format path and the
|
||||
|
@ -413,21 +409,6 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
|
|||
}
|
||||
}
|
||||
|
||||
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
|
||||
func getSQLScanner(target any) sql.Scanner {
|
||||
val := reflect.ValueOf(target)
|
||||
for val.Kind() == reflect.Ptr {
|
||||
if _, ok := val.Interface().(sql.Scanner); ok {
|
||||
if val.IsNil() {
|
||||
val.Set(reflect.New(val.Type().Elem()))
|
||||
}
|
||||
return val.Interface().(sql.Scanner)
|
||||
}
|
||||
val = val.Elem()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type scanPlanString struct{}
|
||||
|
||||
func (scanPlanString) Scan(src []byte, dst any) error {
|
||||
|
|
|
@ -91,7 +91,25 @@ func initDefaultMap() {
|
|||
defaultMap.RegisterType(&Type{Name: "varchar", OID: VarcharOID, Codec: TextCodec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "xid", OID: XIDOID, Codec: Uint32Codec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "xid8", OID: XID8OID, Codec: Uint64Codec{}})
|
||||
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{Marshal: xml.Marshal, Unmarshal: xml.Unmarshal}})
|
||||
defaultMap.RegisterType(&Type{Name: "xml", OID: XMLOID, Codec: &XMLCodec{
|
||||
Marshal: xml.Marshal,
|
||||
// xml.Unmarshal does not support unmarshalling into *any. However, XMLCodec.DecodeValue calls Unmarshal with a
|
||||
// *any. Wrap xml.Marshal with a function that copies the data into a new byte slice in this case. Not implementing
|
||||
// directly in XMLCodec.DecodeValue to allow for the unlikely possibility that someone uses an alternative XML
|
||||
// unmarshaler that does support unmarshalling into *any.
|
||||
//
|
||||
// https://github.com/jackc/pgx/issues/2227
|
||||
// https://github.com/jackc/pgx/pull/2228
|
||||
Unmarshal: func(data []byte, v any) error {
|
||||
if v, ok := v.(*any); ok {
|
||||
dstBuf := make([]byte, len(data))
|
||||
copy(dstBuf, data)
|
||||
*v = dstBuf
|
||||
return nil
|
||||
}
|
||||
return xml.Unmarshal(data, v)
|
||||
},
|
||||
}})
|
||||
|
||||
// Range types
|
||||
defaultMap.RegisterType(&Type{Name: "daterange", OID: DaterangeOID, Codec: &RangeCodec{ElementType: defaultMap.oidToType[DateOID]}})
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
)
|
||||
|
||||
const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
|
||||
const jsonISO8601 = "2006-01-02T15:04:05.999999999"
|
||||
|
||||
type TimestampScanner interface {
|
||||
ScanTimestamp(v Timestamp) error
|
||||
|
@ -76,7 +77,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) {
|
|||
|
||||
switch ts.InfinityModifier {
|
||||
case Finite:
|
||||
s = ts.Time.Format(time.RFC3339Nano)
|
||||
s = ts.Time.Format(jsonISO8601)
|
||||
case Infinity:
|
||||
s = "infinity"
|
||||
case NegativeInfinity:
|
||||
|
@ -104,15 +105,23 @@ func (ts *Timestamp) UnmarshalJSON(b []byte) error {
|
|||
case "-infinity":
|
||||
*ts = Timestamp{Valid: true, InfinityModifier: -Infinity}
|
||||
default:
|
||||
// PostgreSQL uses ISO 8601 wihout timezone for to_json function and casting from a string to timestampt
|
||||
tim, err := time.Parse(time.RFC3339Nano, *s+"Z")
|
||||
if err != nil {
|
||||
return err
|
||||
// Parse time with or without timezonr
|
||||
tss := *s
|
||||
// PostgreSQL uses ISO 8601 without timezone for to_json function and casting from a string to timestampt
|
||||
tim, err := time.Parse(time.RFC3339Nano, tss)
|
||||
if err == nil {
|
||||
*ts = Timestamp{Time: tim, Valid: true}
|
||||
return nil
|
||||
}
|
||||
|
||||
*ts = Timestamp{Time: tim, Valid: true}
|
||||
tim, err = time.ParseInLocation(jsonISO8601, tss, time.UTC)
|
||||
if err == nil {
|
||||
*ts = Timestamp{Time: tim, Valid: true}
|
||||
return nil
|
||||
}
|
||||
ts.Valid = false
|
||||
return fmt.Errorf("cannot unmarshal %s to timestamp with layout %s or %s (%w)",
|
||||
*s, time.RFC3339Nano, jsonISO8601, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,12 +2,14 @@ package pgtype_test
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
pgx "github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -100,13 +102,24 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTimestampMarshalJSON(t *testing.T) {
|
||||
|
||||
tsStruct := struct {
|
||||
TS pgtype.Timestamp `json:"ts"`
|
||||
}{}
|
||||
|
||||
tm := time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC)
|
||||
tsString := "\"" + tm.Format("2006-01-02T15:04:05") + "\"" // `"2012-03-29T10:05:45"`
|
||||
var pgt pgtype.Timestamp
|
||||
_ = pgt.Scan(tm)
|
||||
|
||||
successfulTests := []struct {
|
||||
source pgtype.Timestamp
|
||||
result string
|
||||
}{
|
||||
{source: pgtype.Timestamp{}, result: "null"},
|
||||
{source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 0, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45Z\""},
|
||||
{source: pgtype.Timestamp{Time: time.Date(2012, 3, 29, 10, 5, 45, 555*1000*1000, time.UTC), Valid: true}, result: "\"2012-03-29T10:05:45.555Z\""},
|
||||
{source: pgtype.Timestamp{Time: tm, Valid: true}, result: tsString},
|
||||
{source: pgt, result: tsString},
|
||||
{source: pgtype.Timestamp{Time: tm.Add(time.Second * 555 / 1000), Valid: true}, result: `"2012-03-29T10:05:45.555"`},
|
||||
{source: pgtype.Timestamp{InfinityModifier: pgtype.Infinity, Valid: true}, result: "\"infinity\""},
|
||||
{source: pgtype.Timestamp{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "\"-infinity\""},
|
||||
}
|
||||
|
@ -116,12 +129,32 @@ func TestTimestampMarshalJSON(t *testing.T) {
|
|||
t.Errorf("%d: %v", i, err)
|
||||
}
|
||||
|
||||
if string(r) != tt.result {
|
||||
if !assert.Equal(t, tt.result, string(r)) {
|
||||
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, string(r))
|
||||
}
|
||||
tsStruct.TS = tt.source
|
||||
b, err := json.Marshal(tsStruct)
|
||||
assert.NoErrorf(t, err, "failed to marshal %v %s", tt.source, err)
|
||||
t2 := tsStruct
|
||||
t2.TS = pgtype.Timestamp{} // Clear out the value so that we can compare after unmarshalling
|
||||
err = json.Unmarshal(b, &t2)
|
||||
assert.NoErrorf(t, err, "failed to unmarshal %v with %s", tt.source, err)
|
||||
assert.True(t, tsStruct.TS.Time.Unix() == t2.TS.Time.Unix())
|
||||
}
|
||||
}
|
||||
|
||||
func TestTimestampUnmarshalJSONErrors(t *testing.T) {
|
||||
tsStruct := struct {
|
||||
TS pgtype.Timestamp `json:"ts"`
|
||||
}{}
|
||||
goodJson1 := []byte(`{"ts":"2012-03-29T10:05:45"}`)
|
||||
assert.NoError(t, json.Unmarshal(goodJson1, &tsStruct))
|
||||
goodJson2 := []byte(`{"ts":"2012-03-29T10:05:45Z"}`)
|
||||
assert.NoError(t, json.Unmarshal(goodJson2, &tsStruct))
|
||||
badJson := []byte(`{"ts":"2012-03-29"}`)
|
||||
assert.Error(t, json.Unmarshal(badJson, &tsStruct))
|
||||
}
|
||||
|
||||
func TestTimestampUnmarshalJSON(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source string
|
||||
|
|
|
@ -79,7 +79,7 @@ func TestXMLCodecUnmarshalSQLNull(t *testing.T) {
|
|||
// A string cannot scan a NULL.
|
||||
str := "foobar"
|
||||
err = conn.QueryRow(ctx, "select null::xml").Scan(&str)
|
||||
assert.EqualError(t, err, "can't scan into dest[0]: cannot scan NULL into *string")
|
||||
assert.EqualError(t, err, "can't scan into dest[0] (col: xml): cannot scan NULL into *string")
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -97,3 +97,32 @@ func TestXMLCodecPointerToPointerToString(t *testing.T) {
|
|||
require.Nil(t, s)
|
||||
})
|
||||
}
|
||||
|
||||
func TestXMLCodecDecodeValue(t *testing.T) {
|
||||
skipCockroachDB(t, "CockroachDB does not support XML.")
|
||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
|
||||
for _, tt := range []struct {
|
||||
sql string
|
||||
expected any
|
||||
}{
|
||||
{
|
||||
sql: `select '<foo>bar</foo>'::xml`,
|
||||
expected: []byte("<foo>bar</foo>"),
|
||||
},
|
||||
} {
|
||||
t.Run(tt.sql, func(t *testing.T) {
|
||||
rows, err := conn.Query(ctx, tt.sql)
|
||||
require.NoError(t, err)
|
||||
|
||||
for rows.Next() {
|
||||
values, err := rows.Values()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, values, 1)
|
||||
require.Equal(t, tt.expected, values[0])
|
||||
}
|
||||
|
||||
require.NoError(t, rows.Err())
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Do not edit. Generated from pgtype/zeronull/int.go.erb
|
||||
// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT.
|
||||
|
||||
package zeronull
|
||||
|
||||
import (
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// Do not edit. Generated from pgtype/zeronull/int_test.go.erb
|
||||
// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT.
|
||||
|
||||
package zeronull_test
|
||||
|
||||
import (
|
||||
|
|
|
@ -147,6 +147,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName
|
|||
assert.Equalf(t, expected.MaxConnIdleTime, actual.MaxConnIdleTime, "%s - MaxConnIdleTime", testName)
|
||||
assert.Equalf(t, expected.MaxConns, actual.MaxConns, "%s - MaxConns", testName)
|
||||
assert.Equalf(t, expected.MinConns, actual.MinConns, "%s - MinConns", testName)
|
||||
assert.Equalf(t, expected.MinIdleConns, actual.MinIdleConns, "%s - MinIdleConns", testName)
|
||||
assert.Equalf(t, expected.HealthCheckPeriod, actual.HealthCheckPeriod, "%s - HealthCheckPeriod", testName)
|
||||
|
||||
assertConnConfigsEqual(t, expected.ConnConfig, actual.ConnConfig, testName)
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
|
||||
var defaultMaxConns = int32(4)
|
||||
var defaultMinConns = int32(0)
|
||||
var defaultMinIdleConns = int32(0)
|
||||
var defaultMaxConnLifetime = time.Hour
|
||||
var defaultMaxConnIdleTime = time.Minute * 30
|
||||
var defaultHealthCheckPeriod = time.Minute
|
||||
|
@ -87,6 +88,7 @@ type Pool struct {
|
|||
afterRelease func(*pgx.Conn) bool
|
||||
beforeClose func(*pgx.Conn)
|
||||
minConns int32
|
||||
minIdleConns int32
|
||||
maxConns int32
|
||||
maxConnLifetime time.Duration
|
||||
maxConnLifetimeJitter time.Duration
|
||||
|
@ -144,6 +146,13 @@ type Config struct {
|
|||
// to create new connections.
|
||||
MinConns int32
|
||||
|
||||
// MinIdleConns is the minimum number of idle connections in the pool. You can increase this to ensure that
|
||||
// there are always idle connections available. This can help reduce tail latencies during request processing,
|
||||
// as you can avoid the latency of establishing a new connection while handling requests. It is superior
|
||||
// to MinConns for this purpose.
|
||||
// Similar to MinConns, the pool might temporarily dip below MinIdleConns after connection closes.
|
||||
MinIdleConns int32
|
||||
|
||||
// HealthCheckPeriod is the duration between checks of the health of idle connections.
|
||||
HealthCheckPeriod time.Duration
|
||||
|
||||
|
@ -189,6 +198,7 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
|
|||
afterRelease: config.AfterRelease,
|
||||
beforeClose: config.BeforeClose,
|
||||
minConns: config.MinConns,
|
||||
minIdleConns: config.MinIdleConns,
|
||||
maxConns: config.MaxConns,
|
||||
maxConnLifetime: config.MaxConnLifetime,
|
||||
maxConnLifetimeJitter: config.MaxConnLifetimeJitter,
|
||||
|
@ -271,7 +281,8 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
|
|||
}
|
||||
|
||||
go func() {
|
||||
p.createIdleResources(ctx, int(p.minConns))
|
||||
targetIdleResources := max(int(p.minConns), int(p.minIdleConns))
|
||||
p.createIdleResources(ctx, targetIdleResources)
|
||||
p.backgroundHealthCheck()
|
||||
}()
|
||||
|
||||
|
@ -334,6 +345,17 @@ func ParseConfig(connString string) (*Config, error) {
|
|||
config.MinConns = defaultMinConns
|
||||
}
|
||||
|
||||
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_min_idle_conns"]; ok {
|
||||
delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns")
|
||||
n, err := strconv.ParseInt(s, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse pool_min_idle_conns: %w", err)
|
||||
}
|
||||
config.MinIdleConns = int32(n)
|
||||
} else {
|
||||
config.MinIdleConns = defaultMinIdleConns
|
||||
}
|
||||
|
||||
if s, ok := config.ConnConfig.Config.RuntimeParams["pool_max_conn_lifetime"]; ok {
|
||||
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
|
||||
d, err := time.ParseDuration(s)
|
||||
|
@ -472,7 +494,9 @@ func (p *Pool) checkMinConns() error {
|
|||
// TotalConns can include ones that are being destroyed but we should have
|
||||
// sleep(500ms) around all of the destroys to help prevent that from throwing
|
||||
// off this check
|
||||
toCreate := p.minConns - p.Stat().TotalConns()
|
||||
|
||||
// Create the number of connections needed to get to both minConns and minIdleConns
|
||||
toCreate := max(p.minConns-p.Stat().TotalConns(), p.minIdleConns-p.Stat().IdleConns())
|
||||
if toCreate > 0 {
|
||||
return p.createIdleResources(context.Background(), int(toCreate))
|
||||
}
|
||||
|
|
|
@ -43,10 +43,11 @@ func TestConnectConfig(t *testing.T) {
|
|||
func TestParseConfigExtractsPoolArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1")
|
||||
config, err := pgxpool.ParseConfig("pool_max_conns=42 pool_min_conns=1 pool_min_idle_conns=2")
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, 42, config.MaxConns)
|
||||
assert.EqualValues(t, 1, config.MinConns)
|
||||
assert.EqualValues(t, 2, config.MinIdleConns)
|
||||
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_max_conns")
|
||||
assert.NotContains(t, config.ConnConfig.Config.RuntimeParams, "pool_min_conns")
|
||||
}
|
||||
|
|
|
@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 {
|
|||
func (s *Stat) MaxIdleDestroyCount() int64 {
|
||||
return s.idleDestroyCount
|
||||
}
|
||||
|
||||
// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires
|
||||
// from the pool for a resource to be released or constructed because the pool was
|
||||
// empty.
|
||||
func (s *Stat) EmptyAcquireWaitTime() time.Duration {
|
||||
return s.s.EmptyAcquireWaitTime()
|
||||
}
|
||||
|
|
|
@ -420,7 +420,7 @@ func TestConnQueryReadWrongTypeError(t *testing.T) {
|
|||
t.Fatal("Expected Rows to have an error after an improper read but it didn't")
|
||||
}
|
||||
|
||||
if rows.Err().Error() != "can't scan into dest[0]: cannot scan int4 (OID 23) in binary format into *time.Time" {
|
||||
if rows.Err().Error() != "can't scan into dest[0] (col: n): cannot scan int4 (OID 23) in binary format into *time.Time" {
|
||||
t.Fatalf("Expected different Rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
|
|
13
rows.go
13
rows.go
|
@ -272,7 +272,7 @@ func (rows *baseRows) Scan(dest ...any) error {
|
|||
|
||||
err := rows.scanPlans[i].Scan(values[i], dst)
|
||||
if err != nil {
|
||||
err = ScanArgError{ColumnIndex: i, Err: err}
|
||||
err = ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err}
|
||||
rows.fatal(err)
|
||||
return err
|
||||
}
|
||||
|
@ -334,11 +334,16 @@ func (rows *baseRows) Conn() *Conn {
|
|||
|
||||
type ScanArgError struct {
|
||||
ColumnIndex int
|
||||
FieldName string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e ScanArgError) Error() string {
|
||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
|
||||
if e.FieldName == "?column?" { // Don't include the fieldname if it's unknown
|
||||
return fmt.Sprintf("can't scan into dest[%d]: %v", e.ColumnIndex, e.Err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("can't scan into dest[%d] (col: %s): %v", e.ColumnIndex, e.FieldName, e.Err)
|
||||
}
|
||||
|
||||
func (e ScanArgError) Unwrap() error {
|
||||
|
@ -366,7 +371,7 @@ func ScanRow(typeMap *pgtype.Map, fieldDescriptions []pgconn.FieldDescription, v
|
|||
|
||||
err := typeMap.Scan(fieldDescriptions[i].DataTypeOID, fieldDescriptions[i].Format, values[i], d)
|
||||
if err != nil {
|
||||
return ScanArgError{ColumnIndex: i, Err: err}
|
||||
return ScanArgError{ColumnIndex: i, FieldName: fieldDescriptions[i].Name, Err: err}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -468,6 +473,8 @@ func CollectOneRow[T any](rows Rows, fn RowToFunc[T]) (T, error) {
|
|||
return value, err
|
||||
}
|
||||
|
||||
// The defer rows.Close() won't have executed yet. If the query returned more than one row, rows would still be open.
|
||||
// rows.Close() must be called before rows.Err() so we explicitly call it here.
|
||||
rows.Close()
|
||||
return value, rows.Err()
|
||||
}
|
||||
|
|
5
tx.go
5
tx.go
|
@ -3,7 +3,6 @@ package pgx
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -103,7 +102,7 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
|
|||
if err != nil {
|
||||
// begin should never fail unless there is an underlying connection issue or
|
||||
// a context timeout. In either case, the connection is possibly broken.
|
||||
c.die(errors.New("failed to begin transaction"))
|
||||
c.die()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -216,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error {
|
|||
tx.closed = true
|
||||
if err != nil {
|
||||
// A rollback failure leaves the connection in an undefined state
|
||||
tx.conn.die(fmt.Errorf("rollback failed: %w", err))
|
||||
tx.conn.die()
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ package pgx_test
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -215,7 +216,12 @@ func testJSONInt16ArrayFailureDueToOverflow(t *testing.T, conn *pgx.Conn, typena
|
|||
input := []int{1, 2, 234432}
|
||||
var output []int16
|
||||
err := conn.QueryRow(context.Background(), "select $1::"+typename, input).Scan(&output)
|
||||
if err == nil || err.Error() != "can't scan into dest[0]: json: cannot unmarshal number 234432 into Go value of type int16" {
|
||||
fieldName := typename
|
||||
if conn.PgConn().ParameterStatus("crdb_version") != "" && typename == "json" {
|
||||
fieldName = "jsonb" // Seems like CockroachDB treats json as jsonb.
|
||||
}
|
||||
expectedMessage := fmt.Sprintf("can't scan into dest[0] (col: %s): json: cannot unmarshal number 234432 into Go value of type int16", fieldName)
|
||||
if err == nil || err.Error() != expectedMessage {
|
||||
t.Errorf("%s: Expected *json.UnmarshalTypeError, but got %v", typename, err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue