mirror of https://github.com/jackc/pgx.git
synced 2025-04-27 21:25:53 +00:00

Merge branch 'jackc:master' into master

commit 4fa324bce8

CHANGELOG.md (12 changed lines)
@@ -1,3 +1,15 @@
+# 5.7.2 (December 21, 2024)
+
+* Fix prepared statement already exists on batch prepare failure
+* Add commit query to tx options (Lucas Hild)
+* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
+* Add message body size limits in frontend and backend (zene)
+* Add xid8 type
+* Ensure planning encodes and scans cannot infinitely recurse
+* Implement pgtype.UUID.String() (Konstantin Grachev)
+* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
+* Update golang.org/x/crypto
+
 # 5.7.1 (September 10, 2024)
 
 * Fix data race in tracelog.TraceLog
README.md (12 changed lines)

@@ -84,7 +84,7 @@ It is also possible to use the `database/sql` interface and convert a connection
 
 ## Testing
 
-See CONTRIBUTING.md for setup instructions.
+See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.
 
 ## Architecture
 
@@ -126,7 +126,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
 
 ## Adapters for 3rd Party Tracers
 
-* [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
+* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
 
 ## Adapters for 3rd Party Loggers
 
@@ -156,7 +156,7 @@ Library for scanning data from a database into Go structs and more.
 A carefully designed SQL client for making using SQL easier,
 more productive, and less error-prone on Golang.
 
-### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
+### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
 
 Adds GSSAPI / Kerberos authentication support.
 
@@ -169,6 +169,10 @@ Explicit data mapping and scanning library for Go structs and slices.
 Type safe and flexible package for scanning database data into Go types.
 Supports, structs, maps, slices and custom mapping functions.
 
-### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
+### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
 
 Code first migration library for native pgx (no database/sql abstraction).
 
+### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
+
+A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.
Rakefile (2 changed lines)

@@ -2,7 +2,7 @@ require "erb"
 
 rule '.go' => '.go.erb' do |task|
   erb = ERB.new(File.read(task.source))
-  File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
+  File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
   sh "goimports", "-w", task.name
 end
 
conn.go (10 changed lines)

@@ -444,7 +444,7 @@ func (c *Conn) IsClosed() bool {
     return c.pgConn.IsClosed()
 }
 
-func (c *Conn) die(err error) {
+func (c *Conn) die() {
     if c.IsClosed() {
         return
     }
@@ -612,14 +612,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
     return result.CommandTag, result.Err
 }
 
-type unknownArgumentTypeQueryExecModeExecError struct {
-    arg any
-}
-
-func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
-    return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
-}
-
 func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
     err := c.eqb.Build(c.typeMap, nil, args)
     if err != nil {
@@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
     require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
 
     require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
-    require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
 }
@@ -161,7 +161,7 @@ type derivedTypeInfo struct {
 // The result of this call can be passed into RegisterTypes to complete the process.
 func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
     m := c.TypeMap()
-    if typeNames == nil || len(typeNames) == 0 {
+    if len(typeNames) == 0 {
         return nil, fmt.Errorf("No type names were supplied.")
     }
 
@@ -169,13 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
     // the SQL not support recent structures such as multirange
     serverVersion, _ := serverVersion(c)
     sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
-    var rows Rows
-    var err error
-    if typeNames == nil {
-        rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
-    } else {
-        rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
-    }
+    rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
     if err != nil {
         return nil, fmt.Errorf("While generating load types query: %w", err)
     }
@@ -232,15 +226,15 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
         default:
             return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
         }
-        if type_ != nil {
-            m.RegisterType(type_)
-            if ti.NspName != "" {
-                nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
-                m.RegisterType(nspType)
-                result = append(result, nspType)
-            }
-            result = append(result, type_)
-        }
+
+        // the type_ is imposible to be null
+        m.RegisterType(type_)
+        if ti.NspName != "" {
+            nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
+            m.RegisterType(nspType)
+            result = append(result, nspType)
+        }
+        result = append(result, type_)
     }
     return result, nil
 }
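As the doc comment above says, the types returned by LoadTypes are meant to be passed on to RegisterTypes. A minimal hedged sketch of that flow, assuming a live connection and a RegisterTypes method on the connection's type map; the type names are placeholders, not names from the repository:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// registerDerivedTypes loads the named enum/composite types from the server and
// registers them on this connection's type map so they can be encoded and scanned.
// "mood" and "public.inventory_item" are hypothetical names used for illustration.
func registerDerivedTypes(ctx context.Context, conn *pgx.Conn) error {
	types, err := conn.LoadTypes(ctx, []string{"mood", "public.inventory_item"})
	if err != nil {
		return err
	}
	conn.TypeMap().RegisterTypes(types)
	return nil
}
```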
internal/sanitize/benchmmark.sh (new file, 60 lines)

@@ -0,0 +1,60 @@
+#!/usr/bin/env bash
+
+current_branch=$(git rev-parse --abbrev-ref HEAD)
+if [ "$current_branch" == "HEAD" ]; then
+  current_branch=$(git rev-parse HEAD)
+fi
+
+restore_branch() {
+  echo "Restoring original branch/commit: $current_branch"
+  git checkout "$current_branch"
+}
+trap restore_branch EXIT
+
+# Check if there are uncommitted changes
+if ! git diff --quiet || ! git diff --cached --quiet; then
+  echo "There are uncommitted changes. Please commit or stash them before running this script."
+  exit 1
+fi
+
+# Ensure that at least one commit argument is passed
+if [ "$#" -lt 1 ]; then
+  echo "Usage: $0 <commit1> <commit2> ... <commitN>"
+  exit 1
+fi
+
+commits=("$@")
+benchmarks_dir=benchmarks
+
+if ! mkdir -p "${benchmarks_dir}"; then
+  echo "Unable to create dir for benchmarks data"
+  exit 1
+fi
+
+# Benchmark results
+bench_files=()
+
+# Run benchmark for each listed commit
+for i in "${!commits[@]}"; do
+  commit="${commits[i]}"
+  git checkout "$commit" || {
+    echo "Failed to checkout $commit"
+    exit 1
+  }
+
+  # Sanitized commmit message
+  commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
+
+  # Benchmark data will go there
+  bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
+
+  if ! go test -bench=. -count=10 >"$bench_file"; then
+    echo "Benchmarking failed for commit $commit"
+    exit 1
+  fi
+
+  bench_files+=("$bench_file")
+done
+
+# go install golang.org/x/perf/cmd/benchstat[@latest]
+benchstat "${bench_files[@]}"
@@ -4,8 +4,10 @@ import (
     "bytes"
     "encoding/hex"
     "fmt"
+    "slices"
     "strconv"
     "strings"
+    "sync"
     "time"
     "unicode/utf8"
 )
@@ -24,18 +26,33 @@ type Query struct {
 // https://github.com/jackc/pgx/issues/1380
 const replacementcharacterwidth = 3
 
+const maxBufSize = 16384 // 16 Ki
+
+var bufPool = &pool[*bytes.Buffer]{
+    new: func() *bytes.Buffer {
+        return &bytes.Buffer{}
+    },
+    reset: func(b *bytes.Buffer) bool {
+        n := b.Len()
+        b.Reset()
+        return n < maxBufSize
+    },
+}
+
+var null = []byte("null")
+
 func (q *Query) Sanitize(args ...any) (string, error) {
     argUse := make([]bool, len(args))
-    buf := &bytes.Buffer{}
+    buf := bufPool.get()
+    defer bufPool.put(buf)
 
     for _, part := range q.Parts {
-        var str string
         switch part := part.(type) {
         case string:
-            str = part
+            buf.WriteString(part)
        case int:
            argIdx := part - 1
 
+            var p []byte
            if argIdx < 0 {
                return "", fmt.Errorf("first sql argument must be > 0")
            }
@@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
            if argIdx >= len(args) {
                return "", fmt.Errorf("insufficient arguments")
            }
 
+            // Prevent SQL injection via Line Comment Creation
+            // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
+            buf.WriteByte(' ')
+
            arg := args[argIdx]
            switch arg := arg.(type) {
            case nil:
-                str = "null"
+                p = null
            case int64:
-                str = strconv.FormatInt(arg, 10)
+                p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
            case float64:
-                str = strconv.FormatFloat(arg, 'f', -1, 64)
+                p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
            case bool:
-                str = strconv.FormatBool(arg)
+                p = strconv.AppendBool(buf.AvailableBuffer(), arg)
            case []byte:
-                str = QuoteBytes(arg)
+                p = QuoteBytes(buf.AvailableBuffer(), arg)
            case string:
-                str = QuoteString(arg)
+                p = QuoteString(buf.AvailableBuffer(), arg)
            case time.Time:
-                str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
+                p = arg.Truncate(time.Microsecond).
+                    AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
            default:
                return "", fmt.Errorf("invalid arg type: %T", arg)
            }
            argUse[argIdx] = true
 
+            buf.Write(p)
+
            // Prevent SQL injection via Line Comment Creation
            // https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
-            str = " " + str + " "
+            buf.WriteByte(' ')
        default:
            return "", fmt.Errorf("invalid Part type: %T", part)
        }
-        buf.WriteString(str)
    }
 
    for i, used := range argUse {
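The switch above relies on the standard library's append-style formatting: each argument is rendered into the buffer's spare capacity via bytes.Buffer.AvailableBuffer (Go 1.21+) and then committed with Write, so no intermediate strings are allocated. A small self-contained sketch of that pattern, separate from the pgx code:

```go
package main

import (
	"bytes"
	"fmt"
	"strconv"
)

func main() {
	var buf bytes.Buffer
	buf.WriteString("volume > ")

	// Format the number directly into buf's unused capacity, then commit it
	// with Write; no intermediate string is created along the way.
	p := strconv.AppendInt(buf.AvailableBuffer(), 500, 10)
	buf.Write(p)

	fmt.Println(buf.String()) // volume > 500
}
```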
@@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
 }
 
 func NewQuery(sql string) (*Query, error) {
-    l := &sqlLexer{
-        src:     sql,
-        stateFn: rawState,
+    query := &Query{}
+    query.init(sql)
+
+    return query, nil
+}
+
+var sqlLexerPool = &pool[*sqlLexer]{
+    new: func() *sqlLexer {
+        return &sqlLexer{}
+    },
+    reset: func(sl *sqlLexer) bool {
+        *sl = sqlLexer{}
+        return true
+    },
+}
+
+func (q *Query) init(sql string) {
+    parts := q.Parts[:0]
+    if parts == nil {
+        // dirty, but fast heuristic to preallocate for ~90% usecases
+        n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
+        parts = make([]Part, 0, n)
     }
 
+    l := sqlLexerPool.get()
+    defer sqlLexerPool.put(l)
+
+    l.src = sql
+    l.stateFn = rawState
+    l.parts = parts
+
     for l.stateFn != nil {
         l.stateFn = l.stateFn(l)
     }
 
-    query := &Query{Parts: l.parts}
-
-    return query, nil
+    q.Parts = l.parts
 }
 
-func QuoteString(str string) string {
-    return "'" + strings.ReplaceAll(str, "'", "''") + "'"
+func QuoteString(dst []byte, str string) []byte {
+    const quote = '\''
+
+    // Preallocate space for the worst case scenario
+    dst = slices.Grow(dst, len(str)*2+2)
+
+    // Add opening quote
+    dst = append(dst, quote)
+
+    // Iterate through the string without allocating
+    for i := 0; i < len(str); i++ {
+        if str[i] == quote {
+            dst = append(dst, quote, quote)
+        } else {
+            dst = append(dst, str[i])
+        }
+    }
+
+    // Add closing quote
+    dst = append(dst, quote)
+
+    return dst
 }
 
-func QuoteBytes(buf []byte) string {
-    return `'\x` + hex.EncodeToString(buf) + "'"
+func QuoteBytes(dst, buf []byte) []byte {
+    if len(buf) == 0 {
+        return append(dst, `'\x'`...)
+    }
+
+    // Calculate required length
+    requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
+
+    // Ensure dst has enough capacity
+    if cap(dst)-len(dst) < requiredLen {
+        newDst := make([]byte, len(dst), len(dst)+requiredLen)
+        copy(newDst, dst)
+        dst = newDst
+    }
+
+    // Record original length and extend slice
+    origLen := len(dst)
+    dst = dst[:origLen+requiredLen]
+
+    // Add prefix
+    dst[origLen] = '\''
+    dst[origLen+1] = '\\'
+    dst[origLen+2] = 'x'
+
+    // Encode bytes directly into dst
+    hex.Encode(dst[origLen+3:len(dst)-1], buf)
+
+    // Add suffix
+    dst[len(dst)-1] = '\''
+
+    return dst
 }
 
 type sqlLexer struct {
@@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
     }
 }
 
+var queryPool = &pool[*Query]{
+    new: func() *Query {
+        return &Query{}
+    },
+    reset: func(q *Query) bool {
+        n := len(q.Parts)
+        q.Parts = q.Parts[:0]
+        return n < 64 // drop too large queries
+    },
+}
+
 // SanitizeSQL replaces placeholder values with args. It quotes and escapes args
 // as necessary. This function is only safe when standard_conforming_strings is
 // on.
 func SanitizeSQL(sql string, args ...any) (string, error) {
-    query, err := NewQuery(sql)
-    if err != nil {
-        return "", err
-    }
+    query := queryPool.get()
+    query.init(sql)
+    defer queryPool.put(query)
+
     return query.Sanitize(args...)
 }
+
+type pool[E any] struct {
+    p     sync.Pool
+    new   func() E
+    reset func(E) bool
+}
+
+func (pool *pool[E]) get() E {
+    v, ok := pool.p.Get().(E)
+    if !ok {
+        v = pool.new()
+    }
+
+    return v
+}
+
+func (p *pool[E]) put(v E) {
+    if p.reset(v) {
+        p.p.Put(v)
+    }
+}
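The pool type added here is a thin generic wrapper over sync.Pool: a typed constructor plus a reset hook that decides whether a value is worth caching, so oversized buffers and very large queries are dropped instead of retained. A self-contained sketch of the same pattern, reusing the 16 KiB cut-off from maxBufSize above; this is an illustration, not the library code itself:

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

// pool wraps sync.Pool with a typed constructor and a reset hook.
// put only caches the value when reset reports it is worth keeping.
type pool[E any] struct {
	p     sync.Pool
	new   func() E
	reset func(E) bool
}

func (pl *pool[E]) get() E {
	v, ok := pl.p.Get().(E)
	if !ok {
		v = pl.new()
	}
	return v
}

func (pl *pool[E]) put(v E) {
	if pl.reset(v) {
		pl.p.Put(v)
	}
}

const maxBufSize = 16384 // same 16 KiB cut-off used in the diff above

var bufPool = &pool[*bytes.Buffer]{
	new: func() *bytes.Buffer { return &bytes.Buffer{} },
	reset: func(b *bytes.Buffer) bool {
		n := b.Len()
		b.Reset()
		return n < maxBufSize // oversized buffers are dropped, not cached
	},
}

func main() {
	buf := bufPool.get()
	defer bufPool.put(buf)

	buf.WriteString("select 1")
	fmt.Println(buf.String())
}
```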
internal/sanitize/sanitize_bench_test.go (new file, 62 lines)

@@ -0,0 +1,62 @@
+// sanitize_benchmark_test.go
+package sanitize_test
+
+import (
+    "testing"
+    "time"
+
+    "github.com/jackc/pgx/v5/internal/sanitize"
+)
+
+var benchmarkSanitizeResult string
+
+const benchmarkQuery = "" +
+    `SELECT *
+    FROM "water_containers"
+    WHERE NOT "id" = $1 -- int64
+        AND "tags" NOT IN $2 -- nil
+        AND "volume" > $3 -- float64
+        AND "transportable" = $4 -- bool
+        AND position($5 IN "sign") -- bytes
+        AND "label" LIKE $6 -- string
+        AND "created_at" > $7; -- time.Time`
+
+var benchmarkArgs = []any{
+    int64(12345),
+    nil,
+    float64(500),
+    true,
+    []byte("8BADF00D"),
+    "kombucha's han'dy awokowa",
+    time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
+}
+
+func BenchmarkSanitize(b *testing.B) {
+    query, err := sanitize.NewQuery(benchmarkQuery)
+    if err != nil {
+        b.Fatalf("failed to create query: %v", err)
+    }
+
+    b.ResetTimer()
+    b.ReportAllocs()
+
+    for i := 0; i < b.N; i++ {
+        benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
+        if err != nil {
+            b.Fatalf("failed to sanitize query: %v", err)
+        }
+    }
+}
+
+var benchmarkNewSQLResult string
+
+func BenchmarkSanitizeSQL(b *testing.B) {
+    b.ReportAllocs()
+    var err error
+    for i := 0; i < b.N; i++ {
+        benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
+        if err != nil {
+            b.Fatalf("failed to sanitize SQL: %v", err)
+        }
+    }
+}
internal/sanitize/sanitize_fuzz_test.go (new file, 55 lines)

@@ -0,0 +1,55 @@
+package sanitize_test
+
+import (
+    "strings"
+    "testing"
+
+    "github.com/jackc/pgx/v5/internal/sanitize"
+)
+
+func FuzzQuoteString(f *testing.F) {
+    const prefix = "prefix"
+    f.Add("new\nline")
+    f.Add("sample text")
+    f.Add("sample q'u'o't'e's")
+    f.Add("select 'quoted $42', $1")
+
+    f.Fuzz(func(t *testing.T, input string) {
+        got := string(sanitize.QuoteString([]byte(prefix), input))
+        want := oldQuoteString(input)
+
+        quoted, ok := strings.CutPrefix(got, prefix)
+        if !ok {
+            t.Fatalf("result has no prefix")
+        }
+
+        if want != quoted {
+            t.Errorf("got %q", got)
+            t.Fatalf("want %q", want)
+        }
+    })
+}
+
+func FuzzQuoteBytes(f *testing.F) {
+    const prefix = "prefix"
+    f.Add([]byte(nil))
+    f.Add([]byte("\n"))
+    f.Add([]byte("sample text"))
+    f.Add([]byte("sample q'u'o't'e's"))
+    f.Add([]byte("select 'quoted $42', $1"))
+
+    f.Fuzz(func(t *testing.T, input []byte) {
+        got := string(sanitize.QuoteBytes([]byte(prefix), input))
+        want := oldQuoteBytes(input)
+
+        quoted, ok := strings.CutPrefix(got, prefix)
+        if !ok {
+            t.Fatalf("result has no prefix")
+        }
+
+        if want != quoted {
+            t.Errorf("got %q", got)
+            t.Fatalf("want %q", want)
+        }
+    })
+}
@@ -1,6 +1,8 @@
 package sanitize_test
 
 import (
+    "encoding/hex"
+    "strings"
     "testing"
     "time"
 
@@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
        }
    }
 }
+
+func TestQuoteString(t *testing.T) {
+    tc := func(name, input string) {
+        t.Run(name, func(t *testing.T) {
+            t.Parallel()
+
+            got := string(sanitize.QuoteString(nil, input))
+            want := oldQuoteString(input)
+
+            if got != want {
+                t.Errorf("got: %s", got)
+                t.Fatalf("want: %s", want)
+            }
+        })
+    }
+
+    tc("empty", "")
+    tc("text", "abcd")
+    tc("with quotes", `one's hat is always a cat`)
+}
+
+// This function was used before optimizations.
+// You should keep for testing purposes - we want to ensure there are no breaking changes.
+func oldQuoteString(str string) string {
+    return "'" + strings.ReplaceAll(str, "'", "''") + "'"
+}
+
+func TestQuoteBytes(t *testing.T) {
+    tc := func(name string, input []byte) {
+        t.Run(name, func(t *testing.T) {
+            t.Parallel()
+
+            got := string(sanitize.QuoteBytes(nil, input))
+            want := oldQuoteBytes(input)
+
+            if got != want {
+                t.Errorf("got: %s", got)
+                t.Fatalf("want: %s", want)
+            }
+        })
+    }
+
+    tc("nil", nil)
+    tc("empty", []byte{})
+    tc("text", []byte("abcd"))
+}
+
+// This function was used before optimizations.
+// You should keep for testing purposes - we want to ensure there are no breaking changes.
+func oldQuoteBytes(buf []byte) string {
+    return `'\x` + hex.EncodeToString(buf) + "'"
+}
@@ -49,7 +49,7 @@ func TestContextWatcherContextCancelled(t *testing.T) {
     require.True(t, cleanupCalled, "Cleanup func was not called")
 }
 
-func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
+func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
     cw := ctxwatch.NewContextWatcher(&testHandler{
         handleCancel: func(context.Context) {
             t.Error("cancel func should not have been called")
@@ -12,7 +12,7 @@ type PasswordMessage struct {
 // Frontend identifies this message as sendable by a PostgreSQL frontend.
 func (*PasswordMessage) Frontend() {}
 
-// Frontend identifies this message as an authentication response.
+// InitialResponse identifies this message as an authentication response.
 func (*PasswordMessage) InitialResponse() {}
 
 // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
@@ -1,4 +1,5 @@
-// Do not edit. Generated from pgtype/int.go.erb
+// Code generated from pgtype/int.go.erb. DO NOT EDIT.
+
 package pgtype
 
 import (
@@ -1,4 +1,5 @@
-// Do not edit. Generated from pgtype/int_test.go.erb
+// Code generated from pgtype/int_test.go.erb. DO NOT EDIT.
+
 package pgtype_test
 
 import (
@@ -1,3 +1,5 @@
+// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT.
+
 package pgtype_test
 
 import (
@@ -25,7 +25,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go
        rows, _ := conn.Query(
            ctx,
            `select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`,
-            []any{pgx.QueryResultFormats{<%= format_code %>}},
+            pgx.QueryResultFormats{<%= format_code %>},
        )
        _, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil })
        if err != nil {
@@ -49,7 +49,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array
        rows, _ := conn.Query(
            ctx,
            `select array_agg(n) from generate_series(1, <%= array_size %>) n`,
-            []any{pgx.QueryResultFormats{<%= format_code %>}},
+            pgx.QueryResultFormats{<%= format_code %>},
        )
        _, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil })
        if err != nil {
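The template change above matters because per-query options such as pgx.QueryResultFormats must be passed directly as query arguments; wrapping them in []any makes them an ordinary parameter and the format request is not applied. A hedged sketch of the corrected call shape in plain Go (the query and function name are illustrative; the format-code constant follows the generated benchmarks):

```go
package example

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5"
)

// printSeries requests binary-format results by passing pgx.QueryResultFormats
// directly as a per-query option, mirroring the corrected template above.
func printSeries(ctx context.Context, conn *pgx.Conn) error {
	rows, err := conn.Query(ctx,
		"select n from generate_series(1, 5) n",
		pgx.QueryResultFormats{pgx.BinaryFormatCode},
	)
	if err != nil {
		return err
	}

	var n int64
	_, err = pgx.ForEachRow(rows, []any{&n}, func() error {
		fmt.Println(n)
		return nil
	})
	return err
}
```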
@@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
 //
 // https://github.com/jackc/pgx/issues/2146
 func isSQLScanner(v any) bool {
+    if _, is := v.(sql.Scanner); is {
+        return true
+    }
+
     val := reflect.ValueOf(v)
     for val.Kind() == reflect.Ptr {
         if _, ok := val.Interface().(sql.Scanner); ok {
@@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
        return fmt.Errorf("cannot scan NULL into %T", dst)
    }
 
-    elem := reflect.ValueOf(dst).Elem()
+    v := reflect.ValueOf(dst)
+    if v.Kind() != reflect.Pointer || v.IsNil() {
+        return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
+    }
+
+    elem := v.Elem()
     elem.Set(reflect.Zero(elem.Type()))
 
     return s.unmarshal(src, dst)
@@ -48,6 +48,7 @@ func TestJSONCodec(t *testing.T) {
        Age  int    `json:"age"`
    }
 
+    var str string
    pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
        {nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
        {map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
@@ -65,6 +66,9 @@ func TestJSONCodec(t *testing.T) {
        {Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
        // Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
        {Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
+
+        // Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
+        {NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
    })
 
    pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@@ -136,6 +140,27 @@ func (i Issue2146) Value() (driver.Value, error) {
    return string(b), err
 }
 
+type NonPointerJSONScanner struct {
+    V *string
+}
+
+func (i NonPointerJSONScanner) Scan(src any) error {
+    switch c := src.(type) {
+    case string:
+        *i.V = c
+    case []byte:
+        *i.V = string(c)
+    default:
+        return errors.New("unknown source type")
+    }
+
+    return nil
+}
+
+func (i NonPointerJSONScanner) Value() (driver.Value, error) {
+    return i.V, nil
+}
+
 // https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
 func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
    defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
@@ -267,7 +292,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
            Unmarshal: func(data []byte, v any) error {
                return json.Unmarshal([]byte(`{"custom":"value"}`), v)
            },
-        }})
+        },
+        })
    },
    })
 
    pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@@ -278,3 +304,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
        }},
    })
 }
+
+func TestJSONCodecScanToNonPointerValues(t *testing.T) {
+    defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
+        n := 44
+        err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
+        require.Error(t, err)
+
+        var i *int
+        err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
+        require.Error(t, err)
+
+        m := 0
+        err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
+        require.NoError(t, err)
+        require.Equal(t, 42, m)
+    })
+}
@@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
 
 // we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
 func getSQLScanner(target any) sql.Scanner {
+    if sc, is := target.(sql.Scanner); is {
+        return sc
+    }
+
     val := reflect.ValueOf(target)
     for val.Kind() == reflect.Ptr {
         if _, ok := val.Interface().(sql.Scanner); ok {
@@ -1,4 +1,5 @@
-// Do not edit. Generated from pgtype/zeronull/int.go.erb
+// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT.
+
 package zeronull
 
 import (
@@ -1,4 +1,5 @@
-// Do not edit. Generated from pgtype/zeronull/int_test.go.erb
+// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT.
+
 package zeronull_test
 
 import (
@@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 {
 func (s *Stat) MaxIdleDestroyCount() int64 {
     return s.idleDestroyCount
 }
+
+// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires
+// from the pool for a resource to be released or constructed because the pool was
+// empty.
+func (s *Stat) EmptyAcquireWaitTime() time.Duration {
+    return s.s.EmptyAcquireWaitTime()
+}
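EmptyAcquireWaitTime joins the existing pgxpool.Stat counters; a brief hedged sketch of reading it, with a placeholder connection string:

```go
package example

import (
	"context"
	"fmt"

	"github.com/jackc/pgx/v5/pgxpool"
)

// reportPoolPressure prints how long callers have cumulatively waited because
// the pool was empty, a hint that MaxConns may be set too low.
func reportPoolPressure(ctx context.Context) error {
	pool, err := pgxpool.New(ctx, "postgres://user:secret@localhost:5432/app") // placeholder DSN
	if err != nil {
		return err
	}
	defer pool.Close()

	stat := pool.Stat()
	fmt.Println("empty-acquire count:", stat.EmptyAcquireCount())
	fmt.Println("empty-acquire wait:", stat.EmptyAcquireWaitTime())
	return nil
}
```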
@@ -106,12 +106,12 @@ func main() {
        panic(err)
    }
 
-    writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
+    err = writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
+    if err != nil {
+        panic(err)
+    }
 
-    writeCertificate("pgx_sslcert.crt", clientBytes)
+    err = writeCertificate("pgx_sslcert.crt", clientBytes)
+    if err != nil {
+        panic(err)
+    }
tx.go (20 changed lines)

@@ -3,7 +3,6 @@ package pgx
 import (
     "context"
-    "errors"
     "fmt"
     "strconv"
     "strings"
 
@@ -48,6 +47,8 @@ type TxOptions struct {
    // BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax
    // such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings.
    BeginQuery string
+    // CommitQuery is the SQL query that will be executed to commit the transaction.
+    CommitQuery string
 }
 
 var emptyTxOptions TxOptions
@@ -101,11 +102,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
    if err != nil {
        // begin should never fail unless there is an underlying connection issue or
        // a context timeout. In either case, the connection is possibly broken.
-        c.die(errors.New("failed to begin transaction"))
+        c.die()
        return nil, err
    }
 
-    return &dbTx{conn: c}, nil
+    return &dbTx{
+        conn:        c,
+        commitQuery: txOptions.CommitQuery,
+    }, nil
 }
 
 // Tx represents a database transaction.
@@ -154,6 +158,7 @@ type dbTx struct {
    conn         *Conn
    savepointNum int64
    closed       bool
+    commitQuery  string
 }
 
 // Begin starts a pseudo nested transaction implemented with a savepoint.
@@ -177,7 +182,12 @@ func (tx *dbTx) Commit(ctx context.Context) error {
        return ErrTxClosed
    }
 
-    commandTag, err := tx.conn.Exec(ctx, "commit")
+    commandSQL := "commit"
+    if tx.commitQuery != "" {
+        commandSQL = tx.commitQuery
+    }
+
+    commandTag, err := tx.conn.Exec(ctx, commandSQL)
    tx.closed = true
    if err != nil {
        if tx.conn.PgConn().TxStatus() != 'I' {
@@ -205,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error {
    tx.closed = true
    if err != nil {
        // A rollback failure leaves the connection in an undefined state
-        tx.conn.die(fmt.Errorf("rollback failed: %w", err))
+        tx.conn.die()
        return err
    }
 
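The new CommitQuery option mirrors BeginQuery on TxOptions; a hedged usage sketch (the BEGIN and COMMIT strings are illustrative placeholders, not recommendations from the repository):

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

// transferWithCustomTx begins and commits a transaction using custom SQL.
// BeginQuery overrides the generated BEGIN statement; CommitQuery replaces the
// plain "commit" issued by dbTx.Commit when it is non-empty.
func transferWithCustomTx(ctx context.Context, conn *pgx.Conn) error {
	tx, err := conn.BeginTx(ctx, pgx.TxOptions{
		BeginQuery:  "BEGIN PRIORITY HIGH", // e.g. CockroachDB syntax, per the BeginQuery doc comment
		CommitQuery: "COMMIT",              // placeholder; any custom commit statement can go here
	})
	if err != nil {
		return err
	}
	defer tx.Rollback(ctx) // no-op once Commit has succeeded

	if _, err := tx.Exec(ctx, "update accounts set balance = balance + 1 where id = $1", 1); err != nil {
		return err
	}
	return tx.Commit(ctx)
}
```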