Merge branch 'jackc:master' into master

This commit is contained in:
Cedric Le Roux 2025-01-09 20:23:28 -08:00 committed by GitHub
commit 4fa324bce8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 515 additions and 75 deletions

View File

@ -1,3 +1,15 @@
# 5.7.2 (December 21, 2024)
* Fix prepared statement already exists on batch prepare failure
* Add commit query to tx options (Lucas Hild)
* Fix pgtype.Timestamp json unmarshal (Shean de Montigny-Desautels)
* Add message body size limits in frontend and backend (zene)
* Add xid8 type
* Ensure planning encodes and scans cannot infinitely recurse
* Implement pgtype.UUID.String() (Konstantin Grachev)
* Switch from ExecParams to Exec in ValidateConnectTargetSessionAttrs functions (Alexander Rumyantsev)
* Update golang.org/x/crypto
# 5.7.1 (September 10, 2024)
* Fix data race in tracelog.TraceLog

View File

@ -84,7 +84,7 @@ It is also possible to use the `database/sql` interface and convert a connection
## Testing
See CONTRIBUTING.md for setup instructions.
See [CONTRIBUTING.md](./CONTRIBUTING.md) for setup instructions.
## Architecture
@ -126,7 +126,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
## Adapters for 3rd Party Tracers
* [https://github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
## Adapters for 3rd Party Loggers
@ -156,7 +156,7 @@ Library for scanning data from a database into Go structs and more.
A carefully designed SQL client for making using SQL easier,
more productive, and less error-prone on Golang.
### [https://github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
### [github.com/otan/gopgkrb5](https://github.com/otan/gopgkrb5)
Adds GSSAPI / Kerberos authentication support.
@ -169,6 +169,10 @@ Explicit data mapping and scanning library for Go structs and slices.
Type safe and flexible package for scanning database data into Go types.
Supports, structs, maps, slices and custom mapping functions.
### [https://github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
### [github.com/z0ne-dev/mgx](https://github.com/z0ne-dev/mgx)
Code first migration library for native pgx (no database/sql abstraction).
### [github.com/amirsalarsafaei/sqlc-pgx-monitoring](https://github.com/amirsalarsafaei/sqlc-pgx-monitoring)
A database monitoring/metrics library for pgx and sqlc. Trace, log and monitor your sqlc query performance using OpenTelemetry.

View File

@ -2,7 +2,7 @@ require "erb"
rule '.go' => '.go.erb' do |task|
erb = ERB.new(File.read(task.source))
File.write(task.name, "// Do not edit. Generated from #{task.source}\n" + erb.result(binding))
File.write(task.name, "// Code generated from #{task.source}. DO NOT EDIT.\n\n" + erb.result(binding))
sh "goimports", "-w", task.name
end

10
conn.go
View File

@ -444,7 +444,7 @@ func (c *Conn) IsClosed() bool {
return c.pgConn.IsClosed()
}
func (c *Conn) die(err error) {
func (c *Conn) die() {
if c.IsClosed() {
return
}
@ -612,14 +612,6 @@ func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription
return result.CommandTag, result.Err
}
type unknownArgumentTypeQueryExecModeExecError struct {
arg any
}
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
}
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
err := c.eqb.Build(c.typeMap, nil, args)
if err != nil {

View File

@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())
require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
}

View File

@ -161,7 +161,7 @@ type derivedTypeInfo struct {
// The result of this call can be passed into RegisterTypes to complete the process.
func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Type, error) {
m := c.TypeMap()
if typeNames == nil || len(typeNames) == 0 {
if len(typeNames) == 0 {
return nil, fmt.Errorf("No type names were supplied.")
}
@ -169,13 +169,7 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
// the SQL not support recent structures such as multirange
serverVersion, _ := serverVersion(c)
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
var rows Rows
var err error
if typeNames == nil {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
} else {
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
}
rows, err := c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
if err != nil {
return nil, fmt.Errorf("While generating load types query: %w", err)
}
@ -232,15 +226,15 @@ func (c *Conn) LoadTypes(ctx context.Context, typeNames []string) ([]*pgtype.Typ
default:
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
}
if type_ != nil {
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
// the type_ is imposible to be null
m.RegisterType(type_)
if ti.NspName != "" {
nspType := &pgtype.Type{Name: ti.NspName + "." + type_.Name, OID: type_.OID, Codec: type_.Codec}
m.RegisterType(nspType)
result = append(result, nspType)
}
result = append(result, type_)
}
return result, nil
}

View File

@ -0,0 +1,60 @@
#!/usr/bin/env bash
current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi
restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT
# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi
# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi
commits=("$@")
benchmarks_dir=benchmarks
if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi
# Benchmark results
bench_files=()
# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}
# Sanitized commmit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')
# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"
if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi
bench_files+=("$bench_file")
done
# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"

View File

@ -4,8 +4,10 @@ import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
@ -24,18 +26,33 @@ type Query struct {
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3
const maxBufSize = 16384 // 16 Ki
var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}
var null = []byte("null")
func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
buf := bufPool.get()
defer bufPool.put(buf)
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
buf.WriteString(part)
case int:
argIdx := part - 1
var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}
@ -43,34 +60,41 @@ func (q *Query) Sanitize(args ...any) (string, error) {
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
p = null
case int64:
str = strconv.FormatInt(arg, 10)
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte:
str = QuoteBytes(arg)
p = QuoteBytes(buf.AvailableBuffer(), arg)
case string:
str = QuoteString(arg)
p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
buf.Write(p)
// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
query := &Query{}
query.init(sql)
return query, nil
}
var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}
func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% usecases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}
l := sqlLexerPool.get()
defer sqlLexerPool.put(l)
l.src = sql
l.stateFn = rawState
l.parts = parts
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
q.Parts = l.parts
}
func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
func QuoteString(dst []byte, str string) []byte {
const quote = '\''
// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)
// Add opening quote
dst = append(dst, quote)
// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}
// Add closing quote
dst = append(dst, quote)
return dst
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}
// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1
// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}
// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]
// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'
// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)
// Add suffix
dst[len(dst)-1] = '\''
return dst
}
type sqlLexer struct {
@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
}
}
var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)
return query.Sanitize(args...)
}
type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}
func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}
return v
}
func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}

View File

@ -0,0 +1,62 @@
// sanitize_benchmark_test.go
package sanitize_test
import (
"testing"
"time"
"github.com/jackc/pgx/v5/internal/sanitize"
)
var benchmarkSanitizeResult string
const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`
var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}
func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}
var benchmarkNewSQLResult string
func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}

View File

@ -0,0 +1,55 @@
package sanitize_test
import (
"strings"
"testing"
"github.com/jackc/pgx/v5/internal/sanitize"
)
func FuzzQuoteString(f *testing.F) {
const prefix = "prefix"
f.Add("new\nline")
f.Add("sample text")
f.Add("sample q'u'o't'e's")
f.Add("select 'quoted $42', $1")
f.Fuzz(func(t *testing.T, input string) {
got := string(sanitize.QuoteString([]byte(prefix), input))
want := oldQuoteString(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}
func FuzzQuoteBytes(f *testing.F) {
const prefix = "prefix"
f.Add([]byte(nil))
f.Add([]byte("\n"))
f.Add([]byte("sample text"))
f.Add([]byte("sample q'u'o't'e's"))
f.Add([]byte("select 'quoted $42', $1"))
f.Fuzz(func(t *testing.T, input []byte) {
got := string(sanitize.QuoteBytes([]byte(prefix), input))
want := oldQuoteBytes(input)
quoted, ok := strings.CutPrefix(got, prefix)
if !ok {
t.Fatalf("result has no prefix")
}
if want != quoted {
t.Errorf("got %q", got)
t.Fatalf("want %q", want)
}
})
}

View File

@ -1,6 +1,8 @@
package sanitize_test
import (
"encoding/hex"
"strings"
"testing"
"time"
@ -227,3 +229,55 @@ func TestQuerySanitize(t *testing.T) {
}
}
}
func TestQuoteString(t *testing.T) {
tc := func(name, input string) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteString(nil, input))
want := oldQuoteString(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("empty", "")
tc("text", "abcd")
tc("with quotes", `one's hat is always a cat`)
}
// This function was used before optimizations.
// You should keep for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
}
func TestQuoteBytes(t *testing.T) {
tc := func(name string, input []byte) {
t.Run(name, func(t *testing.T) {
t.Parallel()
got := string(sanitize.QuoteBytes(nil, input))
want := oldQuoteBytes(input)
if got != want {
t.Errorf("got: %s", got)
t.Fatalf("want: %s", want)
}
})
}
tc("nil", nil)
tc("empty", []byte{})
tc("text", []byte("abcd"))
}
// This function was used before optimizations.
// You should keep for testing purposes - we want to ensure there are no breaking changes.
func oldQuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}

View File

@ -49,7 +49,7 @@ func TestContextWatcherContextCancelled(t *testing.T) {
require.True(t, cleanupCalled, "Cleanup func was not called")
}
func TestContextWatcherUnwatchdBeforeContextCancelled(t *testing.T) {
func TestContextWatcherUnwatchedBeforeContextCancelled(t *testing.T) {
cw := ctxwatch.NewContextWatcher(&testHandler{
handleCancel: func(context.Context) {
t.Error("cancel func should not have been called")

View File

@ -12,7 +12,7 @@ type PasswordMessage struct {
// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*PasswordMessage) Frontend() {}
// Frontend identifies this message as an authentication response.
// InitialResponse identifies this message as an authentication response.
func (*PasswordMessage) InitialResponse() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message

View File

@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/int.go.erb
// Code generated from pgtype/int.go.erb. DO NOT EDIT.
package pgtype
import (

View File

@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/int_test.go.erb
// Code generated from pgtype/int_test.go.erb. DO NOT EDIT.
package pgtype_test
import (

View File

@ -1,3 +1,5 @@
// Code generated from pgtype/integration_benchmark_test.go.erb. DO NOT EDIT.
package pgtype_test
import (

View File

@ -25,7 +25,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_<%= pg_type %>_to_Go_<%= go
rows, _ := conn.Query(
ctx,
`select <% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>n::<%= pg_type %> + <%= col_idx%><% end %> from generate_series(1, <%= rows %>) n`,
[]any{pgx.QueryResultFormats{<%= format_code %>}},
pgx.QueryResultFormats{<%= format_code %>},
)
_, err := pgx.ForEachRow(rows, []any{<% columns.times do |col_idx| %><% if col_idx != 0 %>, <% end %>&v[<%= col_idx%>]<% end %>}, func() error { return nil })
if err != nil {
@ -49,7 +49,7 @@ func BenchmarkQuery<%= format_name %>FormatDecode_PG_Int4Array_With_Go_Int4Array
rows, _ := conn.Query(
ctx,
`select array_agg(n) from generate_series(1, <%= array_size %>) n`,
[]any{pgx.QueryResultFormats{<%= format_code %>}},
pgx.QueryResultFormats{<%= format_code %>},
)
_, err := pgx.ForEachRow(rows, []any{&v}, func() error { return nil })
if err != nil {

View File

@ -161,6 +161,10 @@ func (c *JSONCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanP
//
// https://github.com/jackc/pgx/issues/2146
func isSQLScanner(v any) bool {
if _, is := v.(sql.Scanner); is {
return true
}
val := reflect.ValueOf(v)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {
@ -212,7 +216,12 @@ func (s *scanPlanJSONToJSONUnmarshal) Scan(src []byte, dst any) error {
return fmt.Errorf("cannot scan NULL into %T", dst)
}
elem := reflect.ValueOf(dst).Elem()
v := reflect.ValueOf(dst)
if v.Kind() != reflect.Pointer || v.IsNil() {
return fmt.Errorf("cannot scan into non-pointer or nil destinations %T", dst)
}
elem := v.Elem()
elem.Set(reflect.Zero(elem.Type()))
return s.unmarshal(src, dst)

View File

@ -48,6 +48,7 @@ func TestJSONCodec(t *testing.T) {
Age int `json:"age"`
}
var str string
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "json", []pgxtest.ValueRoundTripTest{
{nil, new(*jsonStruct), isExpectedEq((*jsonStruct)(nil))},
{map[string]any(nil), new(*string), isExpectedEq((*string)(nil))},
@ -65,6 +66,9 @@ func TestJSONCodec(t *testing.T) {
{Issue1805(7), new(Issue1805), isExpectedEq(Issue1805(7))},
// Test driver.Scanner is used before json.Unmarshaler (https://github.com/jackc/pgx/issues/2146)
{Issue2146(7), new(*Issue2146), isPtrExpectedEq(Issue2146(7))},
// Test driver.Scanner without pointer receiver (https://github.com/jackc/pgx/issues/2204)
{NonPointerJSONScanner{V: stringPtr("{}")}, NonPointerJSONScanner{V: &str}, func(a any) bool { return str == "{}" }},
})
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@ -136,6 +140,27 @@ func (i Issue2146) Value() (driver.Value, error) {
return string(b), err
}
type NonPointerJSONScanner struct {
V *string
}
func (i NonPointerJSONScanner) Scan(src any) error {
switch c := src.(type) {
case string:
*i.V = c
case []byte:
*i.V = string(c)
default:
return errors.New("unknown source type")
}
return nil
}
func (i NonPointerJSONScanner) Value() (driver.Value, error) {
return i.V, nil
}
// https://github.com/jackc/pgx/issues/1273#issuecomment-1221414648
func TestJSONCodecUnmarshalSQLNull(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
@ -267,7 +292,8 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
},
}})
},
})
}
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "json", []pgxtest.ValueRoundTripTest{
@ -278,3 +304,20 @@ func TestJSONCodecCustomMarshal(t *testing.T) {
}},
})
}
func TestJSONCodecScanToNonPointerValues(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
n := 44
err := conn.QueryRow(ctx, "select '42'::jsonb").Scan(n)
require.Error(t, err)
var i *int
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(i)
require.Error(t, err)
m := 0
err = conn.QueryRow(ctx, "select '42'::jsonb").Scan(&m)
require.NoError(t, err)
require.Equal(t, 42, m)
})
}

View File

@ -415,6 +415,10 @@ func (plan *scanPlanSQLScanner) Scan(src []byte, dst any) error {
// we don't know if the target is a sql.Scanner or a pointer on a sql.Scanner, so we need to check recursively
func getSQLScanner(target any) sql.Scanner {
if sc, is := target.(sql.Scanner); is {
return sc
}
val := reflect.ValueOf(target)
for val.Kind() == reflect.Ptr {
if _, ok := val.Interface().(sql.Scanner); ok {

View File

@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/zeronull/int.go.erb
// Code generated from pgtype/zeronull/int.go.erb. DO NOT EDIT.
package zeronull
import (

View File

@ -1,4 +1,5 @@
// Do not edit. Generated from pgtype/zeronull/int_test.go.erb
// Code generated from pgtype/zeronull/int_test.go.erb. DO NOT EDIT.
package zeronull_test
import (

View File

@ -82,3 +82,10 @@ func (s *Stat) MaxLifetimeDestroyCount() int64 {
func (s *Stat) MaxIdleDestroyCount() int64 {
return s.idleDestroyCount
}
// EmptyAcquireWaitTime returns the cumulative time waited for successful acquires
// from the pool for a resource to be released or constructed because the pool was
// empty.
func (s *Stat) EmptyAcquireWaitTime() time.Duration {
return s.s.EmptyAcquireWaitTime()
}

View File

@ -106,12 +106,12 @@ func main() {
panic(err)
}
writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
err = writeEncryptedPrivateKey("pgx_sslcert.key", clientCertPrivKey, "certpw")
if err != nil {
panic(err)
}
writeCertificate("pgx_sslcert.crt", clientBytes)
err = writeCertificate("pgx_sslcert.crt", clientBytes)
if err != nil {
panic(err)
}

20
tx.go
View File

@ -3,7 +3,6 @@ package pgx
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
@ -48,6 +47,8 @@ type TxOptions struct {
// BeginQuery is the SQL query that will be executed to begin the transaction. This allows using non-standard syntax
// such as BEGIN PRIORITY HIGH with CockroachDB. If set this will override the other settings.
BeginQuery string
// CommitQuery is the SQL query that will be executed to commit the transaction.
CommitQuery string
}
var emptyTxOptions TxOptions
@ -101,11 +102,14 @@ func (c *Conn) BeginTx(ctx context.Context, txOptions TxOptions) (Tx, error) {
if err != nil {
// begin should never fail unless there is an underlying connection issue or
// a context timeout. In either case, the connection is possibly broken.
c.die(errors.New("failed to begin transaction"))
c.die()
return nil, err
}
return &dbTx{conn: c}, nil
return &dbTx{
conn: c,
commitQuery: txOptions.CommitQuery,
}, nil
}
// Tx represents a database transaction.
@ -154,6 +158,7 @@ type dbTx struct {
conn *Conn
savepointNum int64
closed bool
commitQuery string
}
// Begin starts a pseudo nested transaction implemented with a savepoint.
@ -177,7 +182,12 @@ func (tx *dbTx) Commit(ctx context.Context) error {
return ErrTxClosed
}
commandTag, err := tx.conn.Exec(ctx, "commit")
commandSQL := "commit"
if tx.commitQuery != "" {
commandSQL = tx.commitQuery
}
commandTag, err := tx.conn.Exec(ctx, commandSQL)
tx.closed = true
if err != nil {
if tx.conn.PgConn().TxStatus() != 'I' {
@ -205,7 +215,7 @@ func (tx *dbTx) Rollback(ctx context.Context) error {
tx.closed = true
if err != nil {
// A rollback failure leaves the connection in an undefined state
tx.conn.die(fmt.Errorf("rollback failed: %w", err))
tx.conn.die()
return err
}