Compare commits

..

No commits in common. "master" and "v5.7.5" have entirely different histories.

88 changed files with 369 additions and 930 deletions

View File

@ -1,21 +0,0 @@
# See for configurations: https://golangci-lint.run/usage/configuration/
version: 2
# See: https://golangci-lint.run/usage/formatters/
formatters:
default: none
enable:
- gofmt # https://pkg.go.dev/cmd/gofmt
- gofumpt # https://github.com/mvdan/gofumpt
settings:
gofmt:
simplify: true # Simplify code: gofmt with `-s` option.
gofumpt:
# Module path which contains the source code being formatted.
# Default: ""
module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod
# Choose whether to use the extra rules.
# Default: false
extra-rules: true

View File

@ -127,7 +127,6 @@ pgerrcode contains constants for the PostgreSQL error codes.
## Adapters for 3rd Party Tracers ## Adapters for 3rd Party Tracers
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer) * [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx)
## Adapters for 3rd Party Loggers ## Adapters for 3rd Party Loggers
@ -185,7 +184,3 @@ Simple Golang implementation for transactional outbox pattern for PostgreSQL usi
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy) ### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
Simplifies working with the pgx library, providing convenient scanning of nested structures. Simplifies working with the pgx library, providing convenient scanning of nested structures.
## [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter)
Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters.

View File

@ -43,10 +43,6 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
} }
// Exec sets fn to be called when the response to qq is received. // Exec sets fn to be called when the response to qq is received.
//
// Note: for simple batch insert uses where it is not required to handle
// each potential error individually, it's sufficient to not set any callbacks,
// and just handle the return value of BatchResults.Close.
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) { func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
qq.Fn = func(br BatchResults) error { qq.Fn = func(br BatchResults) error {
ct, err := br.Exec() ct, err := br.Exec()
@ -87,7 +83,7 @@ func (b *Batch) Len() int {
type BatchResults interface { type BatchResults interface {
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer // Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
// calling Exec on the QueuedQuery, or just calling Close. // calling Exec on the QueuedQuery.
Exec() (pgconn.CommandTag, error) Exec() (pgconn.CommandTag, error)
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer // Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
@ -102,9 +98,6 @@ type BatchResults interface {
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an // QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
// error or the batch encounters an error subsequent callback functions will not be called. // error or the batch encounters an error subsequent callback functions will not be called.
// //
// For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks,
// and just handle the return value of Close.
//
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch // Close must be called before the underlying connection can be used again. Any error that occurred during a batch
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying // operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
// connection will have been closed. // connection will have been closed.
@ -214,6 +207,7 @@ func (br *batchResults) Query() (Rows, error) {
func (br *batchResults) QueryRow() Row { func (br *batchResults) QueryRow() Row {
rows, _ := br.Query() rows, _ := br.Query()
return (*connRow)(rows.(*baseRows)) return (*connRow)(rows.(*baseRows))
} }
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
@ -226,8 +220,6 @@ func (br *batchResults) Close() error {
} }
br.endTraced = true br.endTraced = true
} }
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
}() }()
if br.err != nil { if br.err != nil {
@ -386,6 +378,7 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
func (br *pipelineBatchResults) QueryRow() Row { func (br *pipelineBatchResults) QueryRow() Row {
rows, _ := br.Query() rows, _ := br.Query()
return (*connRow)(rows.(*baseRows)) return (*connRow)(rows.(*baseRows))
} }
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to // Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
@ -398,8 +391,6 @@ func (br *pipelineBatchResults) Close() error {
} }
br.endTraced = true br.endTraced = true
} }
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
}() }()
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil { if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
@ -450,20 +441,3 @@ func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, er
br.qqIdx++ br.qqIdx++
return bi.SQL, bi.Arguments, nil return bi.SQL, bi.Arguments, nil
} }
// invalidates statement and description caches on batch results error
func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) {
if err != nil && conn != nil && b != nil {
if sc := conn.statementCache; sc != nil {
for _, bi := range b.QueuedQueries {
sc.Invalidate(bi.SQL)
}
}
if sc := conn.descriptionCache; sc != nil {
for _, bi := range b.QueuedQueries {
sc.Invalidate(bi.SQL)
}
}
}
}

View File

@ -488,6 +488,7 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n")
batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n")
@ -538,6 +539,7 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}) })
} }
@ -548,6 +550,7 @@ func TestConnSendBatchQueryError(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0") batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
batch.Queue("select n from generate_series(0,5) n") batch.Queue("select n from generate_series(0,5) n")
@ -577,6 +580,7 @@ func TestConnSendBatchQueryError(t *testing.T) {
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") { if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
t.Errorf("br.Close() => %v, want error code %v", err, 22012) t.Errorf("br.Close() => %v, want error code %v", err, 22012)
} }
}) })
} }
@ -587,6 +591,7 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
batch := &pgx.Batch{} batch := &pgx.Batch{}
batch.Queue("select 1 1") batch.Queue("select 1 1")
@ -602,6 +607,7 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
if err == nil { if err == nil {
t.Error("Expected error") t.Error("Expected error")
} }
}) })
} }
@ -612,6 +618,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger( sql := `create temporary table ledger(
id serial primary key, id serial primary key,
description varchar not null, description varchar not null,
@ -640,6 +647,7 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
} }
br.Close() br.Close()
}) })
} }
@ -650,6 +658,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger( sql := `create temporary table ledger(
id serial primary key, id serial primary key,
description varchar not null, description varchar not null,
@ -678,6 +687,7 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
} }
br.Close() br.Close()
}) })
} }
@ -688,6 +698,7 @@ func TestTxSendBatch(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger1( sql := `create temporary table ledger1(
id serial primary key, id serial primary key,
description varchar not null description varchar not null
@ -746,6 +757,7 @@ func TestTxSendBatch(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
}) })
} }
@ -756,6 +768,7 @@ func TestTxSendBatchRollback(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
sql := `create temporary table ledger1( sql := `create temporary table ledger1(
id serial primary key, id serial primary key,
description varchar not null description varchar not null
@ -782,6 +795,7 @@ func TestTxSendBatchRollback(t *testing.T) {
if count != 0 { if count != 0 {
t.Errorf("count => %v, want %v", count, 0) t.Errorf("count => %v, want %v", count, 0)
} }
}) })
} }
@ -841,6 +855,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
defer cancel() defer cancel()
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)") pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
mustExec(t, conn, `create temporary table t ( mustExec(t, conn, `create temporary table t (
@ -879,6 +894,7 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" { if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
t.Fatalf("expected error 23505, got %v", err) t.Fatalf("expected error 23505, got %v", err)
} }
}) })
} }

View File

@ -516,6 +516,7 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
} }
return rowCount, nil return rowCount, nil
} }
func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) { func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
@ -534,8 +535,7 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
src := newBenchmarkWriteTableCopyFromSrc(n) src := newBenchmarkWriteTableCopyFromSrc(n)
_, err := multiInsert(conn, "t", _, err := multiInsert(conn, "t",
[]string{ []string{"varchar_1",
"varchar_1",
"varchar_2", "varchar_2",
"varchar_null_1", "varchar_null_1",
"date_1", "date_1",
@ -547,8 +547,7 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
"tstz_2", "tstz_2",
"bool_1", "bool_1",
"bool_2", "bool_2",
"bool_3", "bool_3"},
},
src) src)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@ -569,8 +568,7 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
_, err := conn.CopyFrom(context.Background(), _, err := conn.CopyFrom(context.Background(),
pgx.Identifier{"t"}, pgx.Identifier{"t"},
[]string{ []string{"varchar_1",
"varchar_1",
"varchar_2", "varchar_2",
"varchar_null_1", "varchar_null_1",
"date_1", "date_1",
@ -582,8 +580,7 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
"tstz_2", "tstz_2",
"bool_1", "bool_1",
"bool_2", "bool_2",
"bool_3", "bool_3"},
},
src) src)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@ -614,7 +611,6 @@ func BenchmarkWrite5RowsViaInsert(b *testing.B) {
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 5) benchmarkWriteNRowsViaMultiInsert(b, 5)
} }
func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) { func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 5) benchmarkWriteNRowsViaBatchInsert(b, 5)
} }
@ -630,7 +626,6 @@ func BenchmarkWrite10RowsViaInsert(b *testing.B) {
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 10) benchmarkWriteNRowsViaMultiInsert(b, 10)
} }
func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) { func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 10) benchmarkWriteNRowsViaBatchInsert(b, 10)
} }
@ -646,7 +641,6 @@ func BenchmarkWrite100RowsViaInsert(b *testing.B) {
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 100) benchmarkWriteNRowsViaMultiInsert(b, 100)
} }
func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) { func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 100) benchmarkWriteNRowsViaBatchInsert(b, 100)
} }
@ -678,7 +672,6 @@ func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) { func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
benchmarkWriteNRowsViaMultiInsert(b, 10000) benchmarkWriteNRowsViaMultiInsert(b, 10000)
} }
func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) { func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) {
benchmarkWriteNRowsViaBatchInsert(b, 10000) benchmarkWriteNRowsViaBatchInsert(b, 10000)
} }
@ -1050,6 +1043,7 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
} }
for _, format := range formats { for _, format := range formats {
b.Run(format.name, func(b *testing.B) { b.Run(format.name, func(b *testing.B) {
br := &BenchRowDecoder{} br := &BenchRowDecoder{}
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
rows, err := conn.Query( rows, err := conn.Query(

View File

@ -172,7 +172,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
delete(config.RuntimeParams, "statement_cache_capacity") delete(config.RuntimeParams, "statement_cache_capacity")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err) return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
} }
statementCacheCapacity = int(n) statementCacheCapacity = int(n)
} }
@ -182,7 +182,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
delete(config.RuntimeParams, "description_cache_capacity") delete(config.RuntimeParams, "description_cache_capacity")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err) return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
} }
descriptionCacheCapacity = int(n) descriptionCacheCapacity = int(n)
} }
@ -202,7 +202,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
case "simple_protocol": case "simple_protocol":
defaultQueryExecMode = QueryExecModeSimpleProtocol defaultQueryExecMode = QueryExecModeSimpleProtocol
default: default:
return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err) return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s)
} }
} }

View File

@ -412,6 +412,7 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
if commandTag.String() != "INSERT 0 1" { if commandTag.String() != "INSERT 0 1" {
t.Fatalf("Unexpected results from Exec: %v", commandTag) t.Fatalf("Unexpected results from Exec: %v", commandTag)
} }
} }
func TestPrepare(t *testing.T) { func TestPrepare(t *testing.T) {
@ -1088,7 +1089,7 @@ func TestLoadRangeType(t *testing.T) {
conn.TypeMap().RegisterType(newRangeType) conn.TypeMap().RegisterType(newRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange") conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
inputRangeType := pgtype.Range[float64]{ var inputRangeType = pgtype.Range[float64]{
Lower: 1.0, Lower: 1.0,
Upper: 2.0, Upper: 2.0,
LowerType: pgtype.Inclusive, LowerType: pgtype.Inclusive,
@ -1128,7 +1129,7 @@ func TestLoadMultiRangeType(t *testing.T) {
conn.TypeMap().RegisterType(newMultiRangeType) conn.TypeMap().RegisterType(newMultiRangeType)
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange") conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")
inputMultiRangeType := pgtype.Multirange[pgtype.Range[float64]]{ var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{
{ {
Lower: 1.0, Lower: 1.0,
Upper: 2.0, Upper: 2.0,
@ -1292,177 +1293,6 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
} }
func TestStmtCacheInvalidationConnWithBatch(t *testing.T) {
ctx := context.Background()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
if conn.PgConn().ParameterStatus("crdb_version") != "" {
t.Skip("Test fails due to different CRDB behavior")
}
// create a table and fill it with some data
_, err := conn.Exec(ctx, `
DROP TABLE IF EXISTS drop_cols;
CREATE TABLE drop_cols (
id SERIAL PRIMARY KEY NOT NULL,
f1 int NOT NULL,
f2 int NOT NULL
);
`)
require.NoError(t, err)
_, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)")
require.NoError(t, err)
getSQL := "SELECT * FROM drop_cols WHERE id = $1"
// This query will populate the statement cache. We don't care about the result.
rows, err := conn.Query(ctx, getSQL, 1)
require.NoError(t, err)
rows.Close()
require.NoError(t, rows.Err())
// Now, change the schema of the table out from under the statement, making it invalid.
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
require.NoError(t, err)
// We must get an error the first time we try to re-execute a bad statement.
// It is up to the application to determine if it wants to try again. We punt to
// the application because there is no clear recovery path in the case of failed transactions
// or batch operations and because automatic retry is tricky and we don't want to get
// it wrong at such an importaint layer of the stack.
batch := &pgx.Batch{}
batch.Queue(getSQL, 1)
br := conn.SendBatch(ctx, batch)
rows, err = br.Query()
require.Error(t, err)
rows.Next()
nextErr := rows.Err()
rows.Close()
err = br.Close()
require.Error(t, err)
for _, err := range []error{nextErr, rows.Err()} {
if err == nil {
t.Fatal(`expected "cached plan must not change result type": no error`)
}
if !strings.Contains(err.Error(), "cached plan must not change result type") {
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
}
}
// On retry, the statement should have been flushed from the cache.
batch = &pgx.Batch{}
batch.Queue(getSQL, 1)
br = conn.SendBatch(ctx, batch)
rows, err = br.Query()
require.NoError(t, err)
rows.Next()
err = rows.Err()
require.NoError(t, err)
rows.Close()
require.NoError(t, rows.Err())
err = br.Close()
require.NoError(t, err)
ensureConnValid(t, conn)
}
func TestStmtCacheInvalidationTxWithBatch(t *testing.T) {
ctx := context.Background()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
if conn.PgConn().ParameterStatus("crdb_version") != "" {
t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)")
}
// create a table and fill it with some data
_, err := conn.Exec(ctx, `
DROP TABLE IF EXISTS drop_cols;
CREATE TABLE drop_cols (
id SERIAL PRIMARY KEY NOT NULL,
f1 int NOT NULL,
f2 int NOT NULL
);
`)
require.NoError(t, err)
_, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)")
require.NoError(t, err)
tx, err := conn.Begin(ctx)
require.NoError(t, err)
getSQL := "SELECT * FROM drop_cols WHERE id = $1"
// This query will populate the statement cache. We don't care about the result.
rows, err := tx.Query(ctx, getSQL, 1)
require.NoError(t, err)
rows.Close()
require.NoError(t, rows.Err())
// Now, change the schema of the table out from under the statement, making it invalid.
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
require.NoError(t, err)
// We must get an error the first time we try to re-execute a bad statement.
// It is up to the application to determine if it wants to try again. We punt to
// the application because there is no clear recovery path in the case of failed transactions
// or batch operations and because automatic retry is tricky and we don't want to get
// it wrong at such an importaint layer of the stack.
batch := &pgx.Batch{}
batch.Queue(getSQL, 1)
br := tx.SendBatch(ctx, batch)
rows, err = br.Query()
require.Error(t, err)
rows.Next()
nextErr := rows.Err()
rows.Close()
err = br.Close()
require.Error(t, err)
for _, err := range []error{nextErr, rows.Err()} {
if err == nil {
t.Fatal(`expected "cached plan must not change result type": no error`)
}
if !strings.Contains(err.Error(), "cached plan must not change result type") {
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
}
}
batch = &pgx.Batch{}
batch.Queue(getSQL, 1)
br = tx.SendBatch(ctx, batch)
rows, err = br.Query()
require.Error(t, err)
rows.Close()
err = rows.Err()
// Retries within the same transaction are errors (really anything except a rollback
// will be an error in this transaction).
require.Error(t, err)
rows.Close()
err = br.Close()
require.Error(t, err)
err = tx.Rollback(ctx)
require.NoError(t, err)
// once we've rolled back, retries will work
batch = &pgx.Batch{}
batch.Queue(getSQL, 1)
br = conn.SendBatch(ctx, batch)
rows, err = br.Query()
require.NoError(t, err)
rows.Next()
err = rows.Err()
require.NoError(t, err)
rows.Close()
err = br.Close()
require.NoError(t, err)
ensureConnValid(t, conn)
}
func TestInsertDurationInterval(t *testing.T) { func TestInsertDurationInterval(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel() defer cancel()

View File

@ -76,6 +76,7 @@ func TestConnCopyWithAllQueryExecModes(t *testing.T) {
} }
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) { func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
for _, mode := range pgxtest.KnownOIDQueryExecModes { for _, mode := range pgxtest.KnownOIDQueryExecModes {
t.Run(mode.String(), func(t *testing.T) { t.Run(mode.String(), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)

View File

@ -31,6 +31,7 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
} }
return nil return nil
} }
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or // Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or

View File

@ -72,4 +72,5 @@ func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int
for i := 0; i < workers; i++ { for i := 0; i < workers; i++ {
<-doneChan <-doneChan
} }
} }

View File

@ -263,7 +263,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte {
return buf return buf
} }
func computeServerSignature(saltedPassword, authMessage []byte) []byte { func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
serverKey := computeHMAC(saltedPassword, []byte("Server Key")) serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
serverSignature := computeHMAC(serverKey, authMessage) serverSignature := computeHMAC(serverKey, authMessage)
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature))) buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))

View File

@ -78,6 +78,7 @@ func BenchmarkExec(b *testing.B) {
} }
} }
_, err = rr.Close() _, err = rr.Close()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -126,6 +127,7 @@ func BenchmarkExecPossibleToCancel(b *testing.B) {
} }
} }
_, err = rr.Close() _, err = rr.Close()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -182,6 +184,7 @@ func BenchmarkExecPrepared(b *testing.B) {
} }
} }
_, err = rr.Close() _, err = rr.Close()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -224,6 +227,7 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
} }
} }
_, err = rr.Close() _, err = rr.Close()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -23,11 +23,9 @@ import (
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgproto3"
) )
type ( type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error type GetSSLPasswordFunc func(ctx context.Context) string
GetSSLPasswordFunc func(ctx context.Context) string
)
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A // Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
// manually initialized Config will cause ConnectConfig to panic. // manually initialized Config will cause ConnectConfig to panic.
@ -181,7 +179,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// //
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated // ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
// values that will be tried in order. This can be used as part of a high availability system. See // values that will be tried in order. This can be used as part of a high availability system. See
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
// //
// # Example URL // # Example URL
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb // postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
@ -208,9 +206,9 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// PGTARGETSESSIONATTRS // PGTARGETSESSIONATTRS
// PGTZ // PGTZ
// //
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables. // See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
// //
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are // See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
// usually but not always the environment variable name downcased and without the "PG" prefix. // usually but not always the environment variable name downcased and without the "PG" prefix.
// //
// Important Security Notes: // Important Security Notes:
@ -218,7 +216,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if // ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
// not set. // not set.
// //
// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of // See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
// security each sslmode provides. // security each sslmode provides.
// //
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of // The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
@ -715,7 +713,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
// According to PostgreSQL documentation, if a root CA file exists, // According to PostgreSQL documentation, if a root CA file exists,
// the behavior of sslmode=require should be the same as that of verify-ca // the behavior of sslmode=require should be the same as that of verify-ca
// //
// See https://www.postgresql.org/docs/current/libpq-ssl.html // See https://www.postgresql.org/docs/12/libpq-ssl.html
if sslrootcert != "" { if sslrootcert != "" {
goto nextCase goto nextCase
} }
@ -786,8 +784,8 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
if sslpassword != "" { if sslpassword != "" {
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword)) decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
} }
// if sslpassword not provided or has decryption error when use it //if sslpassword not provided or has decryption error when use it
// try to find sslpassword with callback function //try to find sslpassword with callback function
if sslpassword == "" || decryptedError != nil { if sslpassword == "" || decryptedError != nil {
if parseConfigOptions.GetSSLPassword != nil { if parseConfigOptions.GetSSLPassword != nil {
sslpassword = parseConfigOptions.GetSSLPassword(context.Background()) sslpassword = parseConfigOptions.GetSSLPassword(context.Background())

View File

@ -133,6 +133,7 @@ func TestParseConfig(t *testing.T) {
name: "sslmode prefer", name: "sslmode prefer",
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer", connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
config: &pgconn.Config{ config: &pgconn.Config{
User: "jack", User: "jack",
Password: "secret", Password: "secret",
Host: "localhost", Host: "localhost",
@ -566,8 +567,7 @@ func TestParseConfig(t *testing.T) {
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "bar", ServerName: "bar",
}, }},
},
{ {
Host: "bar", Host: "bar",
Port: defaultPort, Port: defaultPort,
@ -579,8 +579,7 @@ func TestParseConfig(t *testing.T) {
TLSConfig: &tls.Config{ TLSConfig: &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
ServerName: "baz", ServerName: "baz",
}, }},
},
{ {
Host: "baz", Host: "baz",
Port: defaultPort, Port: defaultPort,
@ -1024,7 +1023,7 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
t.Parallel() t.Parallel()
tfName := filepath.Join(t.TempDir(), "config") tfName := filepath.Join(t.TempDir(), "config")
err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0o600) err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0600)
require.NoError(t, err) require.NoError(t, err)
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName) connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName)
@ -1062,7 +1061,7 @@ host = def.example.com
dbname = defdb dbname = defdb
user = defuser user = defuser
application_name = spaced string application_name = spaced string
`), 0o600) `), 0600)
require.NoError(t, err) require.NoError(t, err)
defaultPort := getDefaultPort(t) defaultPort := getDefaultPort(t)

View File

@ -27,7 +27,7 @@ func Timeout(err error) bool {
} }
// PgError represents an error reported by the PostgreSQL server. See // PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for // http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
// detailed field description. // detailed field description.
type PgError struct { type PgError struct {
Severity string Severity string
@ -112,14 +112,6 @@ type ParseConfigError struct {
err error err error
} }
func NewParseConfigError(conn, msg string, err error) error {
return &ParseConfigError{
ConnString: conn,
msg: msg,
err: err,
}
}
func (e *ParseConfigError) Error() string { func (e *ParseConfigError) Error() string {
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only // Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would // return a static string. That would ensure that the error message cannot leak a password. The ConnString field would

View File

@ -1,3 +1,11 @@
// File export_test exports some methods for better testing. // File export_test exports some methods for better testing.
package pgconn package pgconn
func NewParseConfigError(conn, msg string, err error) error {
return &ParseConfigError{
ConnString: conn,
msg: msg,
err: err,
}
}

View File

@ -28,7 +28,7 @@ func RegisterGSSProvider(newGSSArg NewGSSFunc) {
// GSS provides GSSAPI authentication (e.g., Kerberos). // GSS provides GSSAPI authentication (e.g., Kerberos).
type GSS interface { type GSS interface {
GetInitToken(host, service string) ([]byte, error) GetInitToken(host string, service string) ([]byte, error)
GetInitTokenFromSPN(spn string) ([]byte, error) GetInitTokenFromSPN(spn string) ([]byte, error)
Continue(inToken []byte) (done bool, outToken []byte, err error) Continue(inToken []byte) (done bool, outToken []byte, err error)
} }

View File

@ -135,7 +135,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio
// //
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
// authentication error will terminate the chain of attempts (like libpq: // authentication error will terminate the chain of attempts (like libpq:
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) { func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
// zero values. // zero values.
@ -991,8 +991,7 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9
// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS
func (pgConn *PgConn) CancelRequest(ctx context.Context) error { func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
// the connection config. This is important in high availability configurations where fallback connections may be // the connection config. This is important in high availability configurations where fallback connections may be
@ -1141,7 +1140,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
// binary format. If resultFormats is nil all results will be in text format. // binary format. If resultFormats is nil all results will be in text format.
// //
// ResultReader must be closed before PgConn can be used again. // ResultReader must be closed before PgConn can be used again.
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) *ResultReader { func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
result := pgConn.execExtendedPrefix(ctx, paramValues) result := pgConn.execExtendedPrefix(ctx, paramValues)
if result.closed { if result.closed {
return result return result
@ -1167,7 +1166,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
// binary format. If resultFormats is nil all results will be in text format. // binary format. If resultFormats is nil all results will be in text format.
// //
// ResultReader must be closed before PgConn can be used again. // ResultReader must be closed before PgConn can be used again.
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader { func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
result := pgConn.execExtendedPrefix(ctx, paramValues) result := pgConn.execExtendedPrefix(ctx, paramValues)
if result.closed { if result.closed {
return result return result
@ -1374,14 +1373,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
close(pgConn.cleanupDone) close(pgConn.cleanupDone)
return CommandTag{}, normalizeTimeoutError(ctx, err) return CommandTag{}, normalizeTimeoutError(ctx, err)
} }
// peekMessage never returns err in the bufferingReceive mode - it only forwards the bufferingReceive variables. msg, _ := pgConn.receiveMessage()
// Therefore, the only case for receiveMessage to return err is during handling of the ErrorResponse message type
// and using pgOnError handler to determine the connection is no longer valid (and thus closing the conn).
msg, serverError := pgConn.receiveMessage()
if serverError != nil {
close(abortCopyChan)
return CommandTag{}, serverError
}
switch msg := msg.(type) { switch msg := msg.(type) {
case *pgproto3.ErrorResponse: case *pgproto3.ErrorResponse:
@ -1720,7 +1712,7 @@ type Batch struct {
} }
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
if batch.err != nil { if batch.err != nil {
return return
} }
@ -1733,7 +1725,7 @@ func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uin
} }
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
if batch.err != nil { if batch.err != nil {
return return
} }
@ -2209,7 +2201,7 @@ func (p *Pipeline) SendDeallocate(name string) {
} }
// SendQueryParams is the pipeline version of *PgConn.QueryParams. // SendQueryParams is the pipeline version of *PgConn.QueryParams.
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) { func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
if p.closed { if p.closed {
return return
} }
@ -2222,7 +2214,7 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
} }
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared. // SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) { func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
if p.closed { if p.closed {
return return
} }

View File

@ -9,7 +9,7 @@ import (
func TestCommandTag(t *testing.T) { func TestCommandTag(t *testing.T) {
t.Parallel() t.Parallel()
tests := []struct { var tests = []struct {
commandTag CommandTag commandTag CommandTag
rowsAffected int64 rowsAffected int64
isInsert bool isInsert bool

View File

@ -2130,63 +2130,6 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
ensureConnValid(t, pgConn) ensureConnValid(t, pgConn)
} }
// https://github.com/jackc/pgx/issues/2364
func TestConnCopyFromConnectionTerminated(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer closeConn(t, pgConn)
if pgConn.ParameterStatus("crdb_version") != "" {
t.Skip("Server does not support pg_terminate_backend")
}
closerConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
time.AfterFunc(500*time.Millisecond, func() {
// defer inside of AfterFunc instead of outer test function because outer function can finish while Read is still in
// progress which could cause closerConn to be closed too soon.
defer closeConn(t, closerConn)
err := closerConn.ExecParams(ctx, "select pg_terminate_backend($1)", [][]byte{[]byte(fmt.Sprintf("%d", pgConn.PID()))}, nil, nil, nil).Read().Err
require.NoError(t, err)
})
_, err = pgConn.Exec(ctx, `create temporary table foo(
a int4,
b varchar
)`).ReadAll()
require.NoError(t, err)
r, w := io.Pipe()
go func() {
for i := 0; i < 5_000; i++ {
a := strconv.Itoa(i)
b := "foo " + a + " bar"
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
if err != nil {
return
}
time.Sleep(time.Millisecond)
}
}()
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
ct, err := pgConn.CopyFrom(ctx, r, copySql)
assert.Equal(t, int64(0), ct.RowsAffected())
assert.Error(t, err)
assert.True(t, pgConn.IsClosed())
select {
case <-pgConn.CleanupDone():
case <-time.After(5 * time.Second):
t.Fatal("Connection cleanup exceeded maximum time")
}
}
func TestConnCopyFromGzipReader(t *testing.T) { func TestConnCopyFromGzipReader(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -9,7 +9,8 @@ import (
) )
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required. // AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
type AuthenticationCleartextPassword struct{} type AuthenticationCleartextPassword struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationCleartextPassword) Backend() {} func (*AuthenticationCleartextPassword) Backend() {}

View File

@ -9,7 +9,8 @@ import (
) )
// AuthenticationOk is a message sent from the backend indicating that authentication was successful. // AuthenticationOk is a message sent from the backend indicating that authentication was successful.
type AuthenticationOk struct{} type AuthenticationOk struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*AuthenticationOk) Backend() {} func (*AuthenticationOk) Backend() {}

View File

@ -4,7 +4,8 @@ import (
"encoding/json" "encoding/json"
) )
type CopyDone struct{} type CopyDone struct {
}
// Backend identifies this message as sendable by the PostgreSQL backend. // Backend identifies this message as sendable by the PostgreSQL backend.
func (*CopyDone) Backend() {} func (*CopyDone) Backend() {}

View File

@ -10,7 +10,8 @@ import (
const gssEncReqNumber = 80877104 const gssEncReqNumber = 80877104
type GSSEncRequest struct{} type GSSEncRequest struct {
}
// Frontend identifies this message as sendable by a PostgreSQL frontend. // Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*GSSEncRequest) Frontend() {} func (*GSSEncRequest) Frontend() {}

View File

@ -332,7 +332,7 @@ func TestJSONUnmarshalRowDescription(t *testing.T) {
} }
func TestJSONUnmarshalBind(t *testing.T) { func TestJSONUnmarshalBind(t *testing.T) {
testCases := []struct { var testCases = []struct {
desc string desc string
data []byte data []byte
}{ }{
@ -348,7 +348,7 @@ func TestJSONUnmarshalBind(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) { t.Run(tc.desc, func(t *testing.T) {
want := Bind{ var want = Bind{
PreparedStatement: "lrupsc_1_0", PreparedStatement: "lrupsc_1_0",
ParameterFormatCodes: []int16{0}, ParameterFormatCodes: []int16{0},
Parameters: [][]byte{[]byte("ABC-123")}, Parameters: [][]byte{[]byte("ABC-123")},

View File

@ -56,6 +56,7 @@ func (*RowDescription) Backend() {}
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message // Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *RowDescription) Decode(src []byte) error { func (dst *RowDescription) Decode(src []byte) error {
if len(src) < 2 { if len(src) < 2 {
return &invalidMessageFormatErr{messageType: "RowDescription"} return &invalidMessageFormatErr{messageType: "RowDescription"}
} }

View File

@ -10,7 +10,8 @@ import (
const sslRequestNumber = 80877103 const sslRequestNumber = 80877103
type SSLRequest struct{} type SSLRequest struct {
}
// Frontend identifies this message as sendable by a PostgreSQL frontend. // Frontend identifies this message as sendable by a PostgreSQL frontend.
func (*SSLRequest) Frontend() {} func (*SSLRequest) Frontend() {}

View File

@ -374,8 +374,8 @@ func quoteArrayElementIfNeeded(src string) string {
return src return src
} }
// Array represents a PostgreSQL array for T. It implements the [ArrayGetter] and [ArraySetter] interfaces. It preserves // Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves
// PostgreSQL dimensions and custom lower bounds. Use [FlatArray] if these are not needed. // PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed.
type Array[T any] struct { type Array[T any] struct {
Elements []T Elements []T
Dims []ArrayDimension Dims []ArrayDimension
@ -419,8 +419,8 @@ func (a Array[T]) ScanIndexType() any {
return new(T) return new(T)
} }
// FlatArray implements the [ArrayGetter] and [ArraySetter] interfaces for any slice of T. It ignores PostgreSQL dimensions // FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions
// and custom lower bounds. Use [Array] to preserve these. // and custom lower bounds. Use Array to preserve these.
type FlatArray[T any] []T type FlatArray[T any] []T
func (a FlatArray[T]) Dimensions() []ArrayDimension { func (a FlatArray[T]) Dimensions() []ArrayDimension {

View File

@ -256,6 +256,7 @@ func TestArrayCodecScanMultipleDimensions(t *testing.T) {
skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)") skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`) rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`)
require.NoError(t, err) require.NoError(t, err)

View File

@ -23,18 +23,16 @@ type Bits struct {
Valid bool Valid bool
} }
// ScanBits implements the [BitsScanner] interface.
func (b *Bits) ScanBits(v Bits) error { func (b *Bits) ScanBits(v Bits) error {
*b = v *b = v
return nil return nil
} }
// BitsValue implements the [BitsValuer] interface.
func (b Bits) BitsValue() (Bits, error) { func (b Bits) BitsValue() (Bits, error) {
return b, nil return b, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Bits) Scan(src any) error { func (dst *Bits) Scan(src any) error {
if src == nil { if src == nil {
*dst = Bits{} *dst = Bits{}
@ -49,7 +47,7 @@ func (dst *Bits) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Bits) Value() (driver.Value, error) { func (src Bits) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -129,6 +127,7 @@ func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -22,18 +22,16 @@ type Bool struct {
Valid bool Valid bool
} }
// ScanBool implements the [BoolScanner] interface.
func (b *Bool) ScanBool(v Bool) error { func (b *Bool) ScanBool(v Bool) error {
*b = v *b = v
return nil return nil
} }
// BoolValue implements the [BoolValuer] interface.
func (b Bool) BoolValue() (Bool, error) { func (b Bool) BoolValue() (Bool, error) {
return b, nil return b, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Bool) Scan(src any) error { func (dst *Bool) Scan(src any) error {
if src == nil { if src == nil {
*dst = Bool{} *dst = Bool{}
@ -63,7 +61,7 @@ func (dst *Bool) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Bool) Value() (driver.Value, error) { func (src Bool) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -72,7 +70,6 @@ func (src Bool) Value() (driver.Value, error) {
return src.Bool, nil return src.Bool, nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Bool) MarshalJSON() ([]byte, error) { func (src Bool) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -85,7 +82,6 @@ func (src Bool) MarshalJSON() ([]byte, error) {
} }
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Bool) UnmarshalJSON(b []byte) error { func (dst *Bool) UnmarshalJSON(b []byte) error {
var v *bool var v *bool
err := json.Unmarshal(b, &v) err := json.Unmarshal(b, &v)
@ -204,6 +200,7 @@ func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte,
} }
func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {
@ -331,7 +328,7 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error {
return s.ScanBool(Bool{Bool: v, Valid: true}) return s.ScanBool(Bool{Bool: v, Valid: true})
} }
// https://www.postgresql.org/docs/current/datatype-boolean.html // https://www.postgresql.org/docs/11/datatype-boolean.html
func planTextToBool(src []byte) (bool, error) { func planTextToBool(src []byte) (bool, error) {
s := string(bytes.ToLower(bytes.TrimSpace(src))) s := string(bytes.ToLower(bytes.TrimSpace(src)))

View File

@ -24,18 +24,16 @@ type Box struct {
Valid bool Valid bool
} }
// ScanBox implements the [BoxScanner] interface.
func (b *Box) ScanBox(v Box) error { func (b *Box) ScanBox(v Box) error {
*b = v *b = v
return nil return nil
} }
// BoxValue implements the [BoxValuer] interface.
func (b Box) BoxValue() (Box, error) { func (b Box) BoxValue() (Box, error) {
return b, nil return b, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Box) Scan(src any) error { func (dst *Box) Scan(src any) error {
if src == nil { if src == nil {
*dst = Box{} *dst = Box{}
@ -50,7 +48,7 @@ func (dst *Box) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Box) Value() (driver.Value, error) { func (src Box) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -129,6 +127,7 @@ func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -527,7 +527,6 @@ func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error {
return nil return nil
} }
func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) { func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) {
ip, ok := netip.AddrFromSlice(w.IP) ip, ok := netip.AddrFromSlice(w.IP)
if !ok { if !ok {
@ -882,6 +881,7 @@ func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error
return nil return nil
} }
} }
func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value { func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value {

View File

@ -148,6 +148,7 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf
} }
func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -25,18 +25,16 @@ type Circle struct {
Valid bool Valid bool
} }
// ScanCircle implements the [CircleScanner] interface.
func (c *Circle) ScanCircle(v Circle) error { func (c *Circle) ScanCircle(v Circle) error {
*c = v *c = v
return nil return nil
} }
// CircleValue implements the [CircleValuer] interface.
func (c Circle) CircleValue() (Circle, error) { func (c Circle) CircleValue() (Circle, error) {
return c, nil return c, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Circle) Scan(src any) error { func (dst *Circle) Scan(src any) error {
if src == nil { if src == nil {
*dst = Circle{} *dst = Circle{}
@ -51,7 +49,7 @@ func (dst *Circle) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Circle) Value() (driver.Value, error) { func (src Circle) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil

View File

@ -276,6 +276,7 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt
default: default:
return nil, fmt.Errorf("unknown format code %d", format) return nil, fmt.Errorf("unknown format code %d", format)
} }
} }
type CompositeBinaryScanner struct { type CompositeBinaryScanner struct {

View File

@ -12,6 +12,7 @@ import (
func TestCompositeCodecTranscode(t *testing.T) { func TestCompositeCodecTranscode(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists ct_test; _, err := conn.Exec(ctx, `drop type if exists ct_test;
create type ct_test as ( create type ct_test as (
@ -89,6 +90,7 @@ func (p *point3d) ScanIndex(i int) any {
func TestCompositeCodecTranscodeStruct(t *testing.T) { func TestCompositeCodecTranscodeStruct(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists point3d; _, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as ( create type point3d as (
@ -123,6 +125,7 @@ create type point3d as (
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) { func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists point3d; _, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as ( create type point3d as (
@ -161,6 +164,7 @@ create type point3d as (
func TestCompositeCodecDecodeValue(t *testing.T) { func TestCompositeCodecDecodeValue(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists point3d; _, err := conn.Exec(ctx, `drop type if exists point3d;
create type point3d as ( create type point3d as (
@ -205,6 +209,7 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
skipCockroachDB(t, "Server does not support composite types from table definitions") skipCockroachDB(t, "Server does not support composite types from table definitions")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop table if exists point3d; _, err := conn.Exec(ctx, `drop table if exists point3d;
create table point3d ( create table point3d (

View File

@ -26,13 +26,11 @@ type Date struct {
Valid bool Valid bool
} }
// ScanDate implements the [DateScanner] interface.
func (d *Date) ScanDate(v Date) error { func (d *Date) ScanDate(v Date) error {
*d = v *d = v
return nil return nil
} }
// DateValue implements the [DateValuer] interface.
func (d Date) DateValue() (Date, error) { func (d Date) DateValue() (Date, error) {
return d, nil return d, nil
} }
@ -42,7 +40,7 @@ const (
infinityDayOffset = 2147483647 infinityDayOffset = 2147483647
) )
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Date) Scan(src any) error { func (dst *Date) Scan(src any) error {
if src == nil { if src == nil {
*dst = Date{} *dst = Date{}
@ -60,7 +58,7 @@ func (dst *Date) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Date) Value() (driver.Value, error) { func (src Date) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -72,7 +70,6 @@ func (src Date) Value() (driver.Value, error) {
return src.Time, nil return src.Time, nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Date) MarshalJSON() ([]byte, error) { func (src Date) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -92,7 +89,6 @@ func (src Date) MarshalJSON() ([]byte, error) {
return json.Marshal(s) return json.Marshal(s)
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Date) UnmarshalJSON(b []byte) error { func (dst *Date) UnmarshalJSON(b []byte) error {
var s *string var s *string
err := json.Unmarshal(b, &s) err := json.Unmarshal(b, &s)
@ -227,6 +223,7 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -10,6 +10,7 @@ import (
func TestEnumCodec(t *testing.T) { func TestEnumCodec(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists enum_test; _, err := conn.Exec(ctx, `drop type if exists enum_test;
create type enum_test as enum ('foo', 'bar', 'baz');`) create type enum_test as enum ('foo', 'bar', 'baz');`)
@ -46,6 +47,7 @@ create type enum_test as enum ('foo', 'bar', 'baz');`)
func TestEnumCodecValues(t *testing.T) { func TestEnumCodecValues(t *testing.T) {
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
_, err := conn.Exec(ctx, `drop type if exists enum_test; _, err := conn.Exec(ctx, `drop type if exists enum_test;
create type enum_test as enum ('foo', 'bar', 'baz');`) create type enum_test as enum ('foo', 'bar', 'baz');`)

View File

@ -16,29 +16,26 @@ type Float4 struct {
Valid bool Valid bool
} }
// ScanFloat64 implements the [Float64Scanner] interface. // ScanFloat64 implements the Float64Scanner interface.
func (f *Float4) ScanFloat64(n Float8) error { func (f *Float4) ScanFloat64(n Float8) error {
*f = Float4{Float32: float32(n.Float64), Valid: n.Valid} *f = Float4{Float32: float32(n.Float64), Valid: n.Valid}
return nil return nil
} }
// Float64Value implements the [Float64Valuer] interface.
func (f Float4) Float64Value() (Float8, error) { func (f Float4) Float64Value() (Float8, error) {
return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil
} }
// ScanInt64 implements the [Int64Scanner] interface.
func (f *Float4) ScanInt64(n Int8) error { func (f *Float4) ScanInt64(n Int8) error {
*f = Float4{Float32: float32(n.Int64), Valid: n.Valid} *f = Float4{Float32: float32(n.Int64), Valid: n.Valid}
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (f Float4) Int64Value() (Int8, error) { func (f Float4) Int64Value() (Int8, error) {
return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (f *Float4) Scan(src any) error { func (f *Float4) Scan(src any) error {
if src == nil { if src == nil {
*f = Float4{} *f = Float4{}
@ -61,7 +58,7 @@ func (f *Float4) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (f Float4) Value() (driver.Value, error) { func (f Float4) Value() (driver.Value, error) {
if !f.Valid { if !f.Valid {
return nil, nil return nil, nil
@ -69,7 +66,6 @@ func (f Float4) Value() (driver.Value, error) {
return float64(f.Float32), nil return float64(f.Float32), nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (f Float4) MarshalJSON() ([]byte, error) { func (f Float4) MarshalJSON() ([]byte, error) {
if !f.Valid { if !f.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -77,7 +73,6 @@ func (f Float4) MarshalJSON() ([]byte, error) {
return json.Marshal(f.Float32) return json.Marshal(f.Float32)
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (f *Float4) UnmarshalJSON(b []byte) error { func (f *Float4) UnmarshalJSON(b []byte) error {
var n *float32 var n *float32
err := json.Unmarshal(b, &n) err := json.Unmarshal(b, &n)
@ -175,6 +170,7 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (new
} }
func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -24,29 +24,26 @@ type Float8 struct {
Valid bool Valid bool
} }
// ScanFloat64 implements the [Float64Scanner] interface. // ScanFloat64 implements the Float64Scanner interface.
func (f *Float8) ScanFloat64(n Float8) error { func (f *Float8) ScanFloat64(n Float8) error {
*f = n *f = n
return nil return nil
} }
// Float64Value implements the [Float64Valuer] interface.
func (f Float8) Float64Value() (Float8, error) { func (f Float8) Float64Value() (Float8, error) {
return f, nil return f, nil
} }
// ScanInt64 implements the [Int64Scanner] interface.
func (f *Float8) ScanInt64(n Int8) error { func (f *Float8) ScanInt64(n Int8) error {
*f = Float8{Float64: float64(n.Int64), Valid: n.Valid} *f = Float8{Float64: float64(n.Int64), Valid: n.Valid}
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (f Float8) Int64Value() (Int8, error) { func (f Float8) Int64Value() (Int8, error) {
return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (f *Float8) Scan(src any) error { func (f *Float8) Scan(src any) error {
if src == nil { if src == nil {
*f = Float8{} *f = Float8{}
@ -69,7 +66,7 @@ func (f *Float8) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (f Float8) Value() (driver.Value, error) { func (f Float8) Value() (driver.Value, error) {
if !f.Valid { if !f.Valid {
return nil, nil return nil, nil
@ -77,7 +74,6 @@ func (f Float8) Value() (driver.Value, error) {
return f.Float64, nil return f.Float64, nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (f Float8) MarshalJSON() ([]byte, error) { func (f Float8) MarshalJSON() ([]byte, error) {
if !f.Valid { if !f.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -85,7 +81,6 @@ func (f Float8) MarshalJSON() ([]byte, error) {
return json.Marshal(f.Float64) return json.Marshal(f.Float64)
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (f *Float8) UnmarshalJSON(b []byte) error { func (f *Float8) UnmarshalJSON(b []byte) error {
var n *float64 var n *float64
err := json.Unmarshal(b, &n) err := json.Unmarshal(b, &n)
@ -213,6 +208,7 @@ func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, e
} }
func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -22,18 +22,16 @@ type HstoreValuer interface {
// associated with its keys. // associated with its keys.
type Hstore map[string]*string type Hstore map[string]*string
// ScanHstore implements the [HstoreScanner] interface.
func (h *Hstore) ScanHstore(v Hstore) error { func (h *Hstore) ScanHstore(v Hstore) error {
*h = v *h = v
return nil return nil
} }
// HstoreValue implements the [HstoreValuer] interface.
func (h Hstore) HstoreValue() (Hstore, error) { func (h Hstore) HstoreValue() (Hstore, error) {
return h, nil return h, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (h *Hstore) Scan(src any) error { func (h *Hstore) Scan(src any) error {
if src == nil { if src == nil {
*h = nil *h = nil
@ -48,7 +46,7 @@ func (h *Hstore) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (h Hstore) Value() (driver.Value, error) { func (h Hstore) Value() (driver.Value, error) {
if h == nil { if h == nil {
return nil, nil return nil, nil
@ -164,6 +162,7 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e
} }
func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {
@ -299,7 +298,7 @@ func (p *hstoreParser) consume() (b byte, end bool) {
return b, false return b, false
} }
func unexpectedByteErr(actualB, expectedB byte) error { func unexpectedByteErr(actualB byte, expectedB byte) error {
return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB) return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB)
} }
@ -317,7 +316,7 @@ func (p *hstoreParser) consumeExpectedByte(expectedB byte) error {
// consumeExpected2 consumes two expected bytes or returns an error. // consumeExpected2 consumes two expected bytes or returns an error.
// This was a bit faster than using a string argument (better inlining? Not sure). // This was a bit faster than using a string argument (better inlining? Not sure).
func (p *hstoreParser) consumeExpected2(one, two byte) error { func (p *hstoreParser) consumeExpected2(one byte, two byte) error {
if p.pos+2 > len(p.str) { if p.pos+2 > len(p.str) {
return errors.New("unexpected end of string") return errors.New("unexpected end of string")
} }

View File

@ -306,13 +306,12 @@ func TestRoundTrip(t *testing.T) {
}) })
} }
} }
} }
func BenchmarkHstoreEncode(b *testing.B) { func BenchmarkHstoreEncode(b *testing.B) {
h := pgtype.Hstore{ h := pgtype.Hstore{"a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"),
"a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"), "d": stringPtr("400"), "e": stringPtr("500")}
"d": stringPtr("400"), "e": stringPtr("500"),
}
serializeConfigs := []struct { serializeConfigs := []struct {
name string name string

View File

@ -24,7 +24,7 @@ type NetipPrefixValuer interface {
NetipPrefixValue() (netip.Prefix, error) NetipPrefixValue() (netip.Prefix, error)
} }
// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are [netip.Prefix] and [netip.Addr]. If // InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If
// IsValid() is false then they are treated as SQL NULL. // IsValid() is false then they are treated as SQL NULL.
type InetCodec struct{} type InetCodec struct{}
@ -107,6 +107,7 @@ func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -26,7 +26,7 @@ type Int2 struct {
Valid bool Valid bool
} }
// ScanInt64 implements the [Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int2) ScanInt64(n Int8) error { func (dst *Int2) ScanInt64(n Int8) error {
if !n.Valid { if !n.Valid {
*dst = Int2{} *dst = Int2{}
@ -44,12 +44,11 @@ func (dst *Int2) ScanInt64(n Int8) error {
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (n Int2) Int64Value() (Int8, error) { func (n Int2) Int64Value() (Int8, error) {
return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Int2) Scan(src any) error { func (dst *Int2) Scan(src any) error {
if src == nil { if src == nil {
*dst = Int2{} *dst = Int2{}
@ -88,7 +87,7 @@ func (dst *Int2) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int2) Value() (driver.Value, error) { func (src Int2) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -96,7 +95,6 @@ func (src Int2) Value() (driver.Value, error) {
return int64(src.Int16), nil return int64(src.Int16), nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Int2) MarshalJSON() ([]byte, error) { func (src Int2) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -104,7 +102,6 @@ func (src Int2) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Int2) UnmarshalJSON(b []byte) error { func (dst *Int2) UnmarshalJSON(b []byte) error {
var n *int16 var n *int16
err := json.Unmarshal(b, &n) err := json.Unmarshal(b, &n)
@ -589,7 +586,7 @@ type Int4 struct {
Valid bool Valid bool
} }
// ScanInt64 implements the [Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int4) ScanInt64(n Int8) error { func (dst *Int4) ScanInt64(n Int8) error {
if !n.Valid { if !n.Valid {
*dst = Int4{} *dst = Int4{}
@ -607,12 +604,11 @@ func (dst *Int4) ScanInt64(n Int8) error {
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (n Int4) Int64Value() (Int8, error) { func (n Int4) Int64Value() (Int8, error) {
return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Int4) Scan(src any) error { func (dst *Int4) Scan(src any) error {
if src == nil { if src == nil {
*dst = Int4{} *dst = Int4{}
@ -651,7 +647,7 @@ func (dst *Int4) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int4) Value() (driver.Value, error) { func (src Int4) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -659,7 +655,6 @@ func (src Int4) Value() (driver.Value, error) {
return int64(src.Int32), nil return int64(src.Int32), nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Int4) MarshalJSON() ([]byte, error) { func (src Int4) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -667,7 +662,6 @@ func (src Int4) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Int4) UnmarshalJSON(b []byte) error { func (dst *Int4) UnmarshalJSON(b []byte) error {
var n *int32 var n *int32
err := json.Unmarshal(b, &n) err := json.Unmarshal(b, &n)
@ -1163,7 +1157,7 @@ type Int8 struct {
Valid bool Valid bool
} }
// ScanInt64 implements the [Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int8) ScanInt64(n Int8) error { func (dst *Int8) ScanInt64(n Int8) error {
if !n.Valid { if !n.Valid {
*dst = Int8{} *dst = Int8{}
@ -1181,12 +1175,11 @@ func (dst *Int8) ScanInt64(n Int8) error {
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (n Int8) Int64Value() (Int8, error) { func (n Int8) Int64Value() (Int8, error) {
return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Int8) Scan(src any) error { func (dst *Int8) Scan(src any) error {
if src == nil { if src == nil {
*dst = Int8{} *dst = Int8{}
@ -1225,7 +1218,7 @@ func (dst *Int8) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int8) Value() (driver.Value, error) { func (src Int8) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -1233,7 +1226,6 @@ func (src Int8) Value() (driver.Value, error) {
return int64(src.Int64), nil return int64(src.Int64), nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Int8) MarshalJSON() ([]byte, error) { func (src Int8) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -1241,7 +1233,6 @@ func (src Int8) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Int8) UnmarshalJSON(b []byte) error { func (dst *Int8) UnmarshalJSON(b []byte) error {
var n *int64 var n *int64
err := json.Unmarshal(b, &n) err := json.Unmarshal(b, &n)

View File

@ -27,7 +27,7 @@ type Int<%= pg_byte_size %> struct {
Valid bool Valid bool
} }
// ScanInt64 implements the [Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error {
if !n.Valid { if !n.Valid {
*dst = Int<%= pg_byte_size %>{} *dst = Int<%= pg_byte_size %>{}
@ -45,12 +45,11 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error {
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) {
return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Int<%= pg_byte_size %>) Scan(src any) error { func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
if src == nil { if src == nil {
*dst = Int<%= pg_byte_size %>{} *dst = Int<%= pg_byte_size %>{}
@ -89,7 +88,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -97,7 +96,6 @@ func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
return int64(src.Int<%= pg_bit_size %>), nil return int64(src.Int<%= pg_bit_size %>), nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) { func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -105,7 +103,6 @@ func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error { func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error {
var n *int<%= pg_bit_size %> var n *int<%= pg_bit_size %>
err := json.Unmarshal(b, &n) err := json.Unmarshal(b, &n)

View File

@ -33,18 +33,16 @@ type Interval struct {
Valid bool Valid bool
} }
// ScanInterval implements the [IntervalScanner] interface.
func (interval *Interval) ScanInterval(v Interval) error { func (interval *Interval) ScanInterval(v Interval) error {
*interval = v *interval = v
return nil return nil
} }
// IntervalValue implements the [IntervalValuer] interface.
func (interval Interval) IntervalValue() (Interval, error) { func (interval Interval) IntervalValue() (Interval, error) {
return interval, nil return interval, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (interval *Interval) Scan(src any) error { func (interval *Interval) Scan(src any) error {
if src == nil { if src == nil {
*interval = Interval{} *interval = Interval{}
@ -59,7 +57,7 @@ func (interval *Interval) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (interval Interval) Value() (driver.Value, error) { func (interval Interval) Value() (driver.Value, error) {
if !interval.Valid { if !interval.Valid {
return nil, nil return nil, nil
@ -159,6 +157,7 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte,
} }
func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -95,8 +95,7 @@ func TestJSONBCodecCustomMarshal(t *testing.T) {
Unmarshal: func(data []byte, v any) error { Unmarshal: func(data []byte, v any) error {
return json.Unmarshal([]byte(`{"custom":"value"}`), v) return json.Unmarshal([]byte(`{"custom":"value"}`), v)
}, },
}, }})
})
} }
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{ pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{

View File

@ -24,13 +24,11 @@ type Line struct {
Valid bool Valid bool
} }
// ScanLine implements the [LineScanner] interface.
func (line *Line) ScanLine(v Line) error { func (line *Line) ScanLine(v Line) error {
*line = v *line = v
return nil return nil
} }
// LineValue implements the [LineValuer] interface.
func (line Line) LineValue() (Line, error) { func (line Line) LineValue() (Line, error) {
return line, nil return line, nil
} }
@ -39,7 +37,7 @@ func (line *Line) Set(src any) error {
return fmt.Errorf("cannot convert %v to Line", src) return fmt.Errorf("cannot convert %v to Line", src)
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (line *Line) Scan(src any) error { func (line *Line) Scan(src any) error {
if src == nil { if src == nil {
*line = Line{} *line = Line{}
@ -54,7 +52,7 @@ func (line *Line) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (line Line) Value() (driver.Value, error) { func (line Line) Value() (driver.Value, error) {
if !line.Valid { if !line.Valid {
return nil, nil return nil, nil
@ -131,6 +129,7 @@ func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -24,18 +24,16 @@ type Lseg struct {
Valid bool Valid bool
} }
// ScanLseg implements the [LsegScanner] interface.
func (lseg *Lseg) ScanLseg(v Lseg) error { func (lseg *Lseg) ScanLseg(v Lseg) error {
*lseg = v *lseg = v
return nil return nil
} }
// LsegValue implements the [LsegValuer] interface.
func (lseg Lseg) LsegValue() (Lseg, error) { func (lseg Lseg) LsegValue() (Lseg, error) {
return lseg, nil return lseg, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (lseg *Lseg) Scan(src any) error { func (lseg *Lseg) Scan(src any) error {
if src == nil { if src == nil {
*lseg = Lseg{} *lseg = Lseg{}
@ -50,7 +48,7 @@ func (lseg *Lseg) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (lseg Lseg) Value() (driver.Value, error) { func (lseg Lseg) Value() (driver.Value, error) {
if !lseg.Valid { if !lseg.Valid {
return nil, nil return nil, nil
@ -129,6 +127,7 @@ func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -374,6 +374,7 @@ parseValueLoop:
} }
return elements, nil return elements, nil
} }
func parseRange(buf *bytes.Buffer) (string, error) { func parseRange(buf *bytes.Buffer) (string, error) {
@ -402,8 +403,8 @@ func parseRange(buf *bytes.Buffer) (string, error) {
// Multirange is a generic multirange type. // Multirange is a generic multirange type.
// //
// T should implement [RangeValuer] and *T should implement [RangeScanner]. However, there does not appear to be a way to // T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
// enforce the [RangeScanner] constraint. // enforce the RangeScanner constraint.
type Multirange[T RangeValuer] []T type Multirange[T RangeValuer] []T
func (r Multirange[T]) IsNull() bool { func (r Multirange[T]) IsNull() bool {

View File

@ -71,6 +71,7 @@ func TestMultirangeCodecDecodeValue(t *testing.T) {
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
for _, tt := range []struct { for _, tt := range []struct {
sql string sql string
expected any expected any

View File

@ -27,20 +27,16 @@ const (
pgNumericNegInfSign = 0xf000 pgNumericNegInfSign = 0xf000
) )
var ( var big0 *big.Int = big.NewInt(0)
big0 *big.Int = big.NewInt(0) var big1 *big.Int = big.NewInt(1)
big1 *big.Int = big.NewInt(1) var big10 *big.Int = big.NewInt(10)
big10 *big.Int = big.NewInt(10) var big100 *big.Int = big.NewInt(100)
big100 *big.Int = big.NewInt(100) var big1000 *big.Int = big.NewInt(1000)
big1000 *big.Int = big.NewInt(1000)
)
var ( var bigNBase *big.Int = big.NewInt(nbase)
bigNBase *big.Int = big.NewInt(nbase) var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
bigNBaseX2 *big.Int = big.NewInt(nbase * nbase) var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase) var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
)
type NumericScanner interface { type NumericScanner interface {
ScanNumeric(v Numeric) error ScanNumeric(v Numeric) error
@ -58,18 +54,15 @@ type Numeric struct {
Valid bool Valid bool
} }
// ScanNumeric implements the [NumericScanner] interface.
func (n *Numeric) ScanNumeric(v Numeric) error { func (n *Numeric) ScanNumeric(v Numeric) error {
*n = v *n = v
return nil return nil
} }
// NumericValue implements the [NumericValuer] interface.
func (n Numeric) NumericValue() (Numeric, error) { func (n Numeric) NumericValue() (Numeric, error) {
return n, nil return n, nil
} }
// Float64Value implements the [Float64Valuer] interface.
func (n Numeric) Float64Value() (Float8, error) { func (n Numeric) Float64Value() (Float8, error) {
if !n.Valid { if !n.Valid {
return Float8{}, nil return Float8{}, nil
@ -99,7 +92,6 @@ func (n Numeric) Float64Value() (Float8, error) {
return Float8{Float64: f, Valid: true}, nil return Float8{Float64: f, Valid: true}, nil
} }
// ScanInt64 implements the [Int64Scanner] interface.
func (n *Numeric) ScanInt64(v Int8) error { func (n *Numeric) ScanInt64(v Int8) error {
if !v.Valid { if !v.Valid {
*n = Numeric{} *n = Numeric{}
@ -110,7 +102,6 @@ func (n *Numeric) ScanInt64(v Int8) error {
return nil return nil
} }
// Int64Value implements the [Int64Valuer] interface.
func (n Numeric) Int64Value() (Int8, error) { func (n Numeric) Int64Value() (Int8, error) {
if !n.Valid { if !n.Valid {
return Int8{}, nil return Int8{}, nil
@ -212,7 +203,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
return accum, rp, digits return accum, rp, digits
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (n *Numeric) Scan(src any) error { func (n *Numeric) Scan(src any) error {
if src == nil { if src == nil {
*n = Numeric{} *n = Numeric{}
@ -227,7 +218,7 @@ func (n *Numeric) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (n Numeric) Value() (driver.Value, error) { func (n Numeric) Value() (driver.Value, error) {
if !n.Valid { if !n.Valid {
return nil, nil return nil, nil
@ -240,7 +231,6 @@ func (n Numeric) Value() (driver.Value, error) {
return string(buf), err return string(buf), err
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (n Numeric) MarshalJSON() ([]byte, error) { func (n Numeric) MarshalJSON() ([]byte, error) {
if !n.Valid { if !n.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -253,7 +243,6 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
return n.numberTextBytes(), nil return n.numberTextBytes(), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (n *Numeric) UnmarshalJSON(src []byte) error { func (n *Numeric) UnmarshalJSON(src []byte) error {
if bytes.Equal(src, []byte(`null`)) { if bytes.Equal(src, []byte(`null`)) {
*n = Numeric{} *n = Numeric{}
@ -564,6 +553,7 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) {
} }
func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -198,6 +198,7 @@ func TestNumericMarshalJSON(t *testing.T) {
skipCockroachDB(t, "server formats numeric text format differently") skipCockroachDB(t, "server formats numeric text format differently")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
for i, tt := range []struct { for i, tt := range []struct {
decString string decString string
}{ }{

View File

@ -25,18 +25,16 @@ type Path struct {
Valid bool Valid bool
} }
// ScanPath implements the [PathScanner] interface.
func (path *Path) ScanPath(v Path) error { func (path *Path) ScanPath(v Path) error {
*path = v *path = v
return nil return nil
} }
// PathValue implements the [PathValuer] interface.
func (path Path) PathValue() (Path, error) { func (path Path) PathValue() (Path, error) {
return path, nil return path, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (path *Path) Scan(src any) error { func (path *Path) Scan(src any) error {
if src == nil { if src == nil {
*path = Path{} *path = Path{}
@ -51,7 +49,7 @@ func (path *Path) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (path Path) Value() (driver.Value, error) { func (path Path) Value() (driver.Value, error) {
if !path.Valid { if !path.Valid {
return nil, nil return nil, nil
@ -156,6 +154,7 @@ func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -2010,7 +2010,7 @@ var valuerReflectType = reflect.TypeFor[driver.Valuer]()
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement // isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
// driver.Valuer if it is only implemented by T. // driver.Valuer if it is only implemented by T.
func isNilDriverValuer(value any) (isNil, callNilDriverValuer bool) { func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) {
if value == nil { if value == nil {
return true, false return true, false
} }

View File

@ -34,19 +34,17 @@ func init() {
} }
// Test for renamed types // Test for renamed types
type ( type _string string
_string string type _bool bool
_bool bool type _uint8 uint8
_uint8 uint8 type _int8 int8
_int8 int8 type _int16 int16
_int16 int16 type _int16Slice []int16
_int16Slice []int16 type _int32Slice []int32
_int32Slice []int32 type _int64Slice []int64
_int64Slice []int64 type _float32Slice []float32
_float32Slice []float32 type _float64Slice []float64
_float64Slice []float64 type _byteSlice []byte
_byteSlice []byte
)
// unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type // unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type
// is not known (e.g. when using the simple protocol). // is not known (e.g. when using the simple protocol).
@ -532,8 +530,7 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
0, 0, 0, 16, 0, 0, 0, 16,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
0, 0, 0, 16, 0, 0, 0, 16,
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
}
m := pgtype.NewMap() m := pgtype.NewMap()
buf, err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode, buf, err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode,

View File

@ -30,13 +30,11 @@ type Point struct {
Valid bool Valid bool
} }
// ScanPoint implements the [PointScanner] interface.
func (p *Point) ScanPoint(v Point) error { func (p *Point) ScanPoint(v Point) error {
*p = v *p = v
return nil return nil
} }
// PointValue implements the [PointValuer] interface.
func (p Point) PointValue() (Point, error) { func (p Point) PointValue() (Point, error) {
return p, nil return p, nil
} }
@ -70,7 +68,7 @@ func parsePoint(src []byte) (*Point, error) {
return &Point{P: Vec2{x, y}, Valid: true}, nil return &Point{P: Vec2{x, y}, Valid: true}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Point) Scan(src any) error { func (dst *Point) Scan(src any) error {
if src == nil { if src == nil {
*dst = Point{} *dst = Point{}
@ -85,7 +83,7 @@ func (dst *Point) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Point) Value() (driver.Value, error) { func (src Point) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -98,7 +96,6 @@ func (src Point) Value() (driver.Value, error) {
return string(buf), err return string(buf), err
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Point) MarshalJSON() ([]byte, error) { func (src Point) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -111,7 +108,6 @@ func (src Point) MarshalJSON() ([]byte, error) {
return buff.Bytes(), nil return buff.Bytes(), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Point) UnmarshalJSON(point []byte) error { func (dst *Point) UnmarshalJSON(point []byte) error {
p, err := parsePoint(point) p, err := parsePoint(point)
if err != nil { if err != nil {
@ -182,6 +178,7 @@ func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, er
} }
func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -24,18 +24,16 @@ type Polygon struct {
Valid bool Valid bool
} }
// ScanPolygon implements the [PolygonScanner] interface.
func (p *Polygon) ScanPolygon(v Polygon) error { func (p *Polygon) ScanPolygon(v Polygon) error {
*p = v *p = v
return nil return nil
} }
// PolygonValue implements the [PolygonValuer] interface.
func (p Polygon) PolygonValue() (Polygon, error) { func (p Polygon) PolygonValue() (Polygon, error) {
return p, nil return p, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (p *Polygon) Scan(src any) error { func (p *Polygon) Scan(src any) error {
if src == nil { if src == nil {
*p = Polygon{} *p = Polygon{}
@ -50,7 +48,7 @@ func (p *Polygon) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (p Polygon) Value() (driver.Value, error) { func (p Polygon) Value() (driver.Value, error) {
if !p.Valid { if !p.Valid {
return nil, nil return nil, nil
@ -141,6 +139,7 @@ func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte,
} }
func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -191,13 +191,11 @@ type untypedBinaryRange struct {
// 18 = [ = 10010 // 18 = [ = 10010
// 24 = = 11000 // 24 = = 11000
const ( const emptyMask = 1
emptyMask = 1 const lowerInclusiveMask = 2
lowerInclusiveMask = 2 const upperInclusiveMask = 4
upperInclusiveMask = 4 const lowerUnboundedMask = 8
lowerUnboundedMask = 8 const upperUnboundedMask = 16
upperUnboundedMask = 16
)
func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) { func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) {
ubr := &untypedBinaryRange{} ubr := &untypedBinaryRange{}
@ -275,6 +273,7 @@ func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) {
} }
return ubr, nil return ubr, nil
} }
// Range is a generic range type. // Range is a generic range type.

View File

@ -75,6 +75,7 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) {
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var r pgtype.Range[pgtype.Int4] var r pgtype.Range[pgtype.Int4]
err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r) err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r)
@ -128,6 +129,7 @@ func TestRangeCodecDecodeValue(t *testing.T) {
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)") skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) { defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
for _, tt := range []struct { for _, tt := range []struct {
sql string sql string
expected any expected any

View File

@ -121,4 +121,5 @@ func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an
default: default:
return nil, fmt.Errorf("unknown format code %d", format) return nil, fmt.Errorf("unknown format code %d", format)
} }
} }

View File

@ -19,18 +19,16 @@ type Text struct {
Valid bool Valid bool
} }
// ScanText implements the [TextScanner] interface.
func (t *Text) ScanText(v Text) error { func (t *Text) ScanText(v Text) error {
*t = v *t = v
return nil return nil
} }
// TextValue implements the [TextValuer] interface.
func (t Text) TextValue() (Text, error) { func (t Text) TextValue() (Text, error) {
return t, nil return t, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Text) Scan(src any) error { func (dst *Text) Scan(src any) error {
if src == nil { if src == nil {
*dst = Text{} *dst = Text{}
@ -49,7 +47,7 @@ func (dst *Text) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Text) Value() (driver.Value, error) { func (src Text) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -57,7 +55,6 @@ func (src Text) Value() (driver.Value, error) {
return src.String, nil return src.String, nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Text) MarshalJSON() ([]byte, error) { func (src Text) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -66,7 +63,6 @@ func (src Text) MarshalJSON() ([]byte, error) {
return json.Marshal(src.String) return json.Marshal(src.String)
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Text) UnmarshalJSON(b []byte) error { func (dst *Text) UnmarshalJSON(b []byte) error {
var s *string var s *string
err := json.Unmarshal(b, &s) err := json.Unmarshal(b, &s)
@ -150,6 +146,7 @@ func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byt
} }
func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case TextFormatCode, BinaryFormatCode: case TextFormatCode, BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -35,18 +35,16 @@ type TID struct {
Valid bool Valid bool
} }
// ScanTID implements the [TIDScanner] interface.
func (b *TID) ScanTID(v TID) error { func (b *TID) ScanTID(v TID) error {
*b = v *b = v
return nil return nil
} }
// TIDValue implements the [TIDValuer] interface.
func (b TID) TIDValue() (TID, error) { func (b TID) TIDValue() (TID, error) {
return b, nil return b, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *TID) Scan(src any) error { func (dst *TID) Scan(src any) error {
if src == nil { if src == nil {
*dst = TID{} *dst = TID{}
@ -61,7 +59,7 @@ func (dst *TID) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src TID) Value() (driver.Value, error) { func (src TID) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -133,6 +131,7 @@ func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -29,18 +29,16 @@ type Time struct {
Valid bool Valid bool
} }
// ScanTime implements the [TimeScanner] interface.
func (t *Time) ScanTime(v Time) error { func (t *Time) ScanTime(v Time) error {
*t = v *t = v
return nil return nil
} }
// TimeValue implements the [TimeValuer] interface.
func (t Time) TimeValue() (Time, error) { func (t Time) TimeValue() (Time, error) {
return t, nil return t, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (t *Time) Scan(src any) error { func (t *Time) Scan(src any) error {
if src == nil { if src == nil {
*t = Time{} *t = Time{}
@ -60,7 +58,7 @@ func (t *Time) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (t Time) Value() (driver.Value, error) { func (t Time) Value() (driver.Value, error) {
if !t.Valid { if !t.Valid {
return nil, nil return nil, nil
@ -139,6 +137,7 @@ func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err
} }
func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -11,10 +11,8 @@ import (
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
const ( const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
pgTimestampFormat = "2006-01-02 15:04:05.999999999" const jsonISO8601 = "2006-01-02T15:04:05.999999999"
jsonISO8601 = "2006-01-02T15:04:05.999999999"
)
type TimestampScanner interface { type TimestampScanner interface {
ScanTimestamp(v Timestamp) error ScanTimestamp(v Timestamp) error
@ -31,18 +29,16 @@ type Timestamp struct {
Valid bool Valid bool
} }
// ScanTimestamp implements the [TimestampScanner] interface.
func (ts *Timestamp) ScanTimestamp(v Timestamp) error { func (ts *Timestamp) ScanTimestamp(v Timestamp) error {
*ts = v *ts = v
return nil return nil
} }
// TimestampValue implements the [TimestampValuer] interface.
func (ts Timestamp) TimestampValue() (Timestamp, error) { func (ts Timestamp) TimestampValue() (Timestamp, error) {
return ts, nil return ts, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (ts *Timestamp) Scan(src any) error { func (ts *Timestamp) Scan(src any) error {
if src == nil { if src == nil {
*ts = Timestamp{} *ts = Timestamp{}
@ -60,7 +56,7 @@ func (ts *Timestamp) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (ts Timestamp) Value() (driver.Value, error) { func (ts Timestamp) Value() (driver.Value, error) {
if !ts.Valid { if !ts.Valid {
return nil, nil return nil, nil
@ -72,7 +68,6 @@ func (ts Timestamp) Value() (driver.Value, error) {
return ts.Time, nil return ts.Time, nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (ts Timestamp) MarshalJSON() ([]byte, error) { func (ts Timestamp) MarshalJSON() ([]byte, error) {
if !ts.Valid { if !ts.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -92,7 +87,6 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) {
return json.Marshal(s) return json.Marshal(s)
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (ts *Timestamp) UnmarshalJSON(b []byte) error { func (ts *Timestamp) UnmarshalJSON(b []byte) error {
var s *string var s *string
err := json.Unmarshal(b, &s) err := json.Unmarshal(b, &s)

View File

@ -102,6 +102,7 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) {
} }
func TestTimestampMarshalJSON(t *testing.T) { func TestTimestampMarshalJSON(t *testing.T) {
tsStruct := struct { tsStruct := struct {
TS pgtype.Timestamp `json:"ts"` TS pgtype.Timestamp `json:"ts"`
}{} }{}

View File

@ -11,12 +11,10 @@ import (
"github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/internal/pgio"
) )
const ( const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07"
pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07" const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00"
pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00" const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00"
pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00" const microsecFromUnixEpochToY2K = 946684800 * 1000000
microsecFromUnixEpochToY2K = 946684800 * 1000000
)
const ( const (
negativeInfinityMicrosecondOffset = -9223372036854775808 negativeInfinityMicrosecondOffset = -9223372036854775808
@ -38,18 +36,16 @@ type Timestamptz struct {
Valid bool Valid bool
} }
// ScanTimestamptz implements the [TimestamptzScanner] interface.
func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error { func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error {
*tstz = v *tstz = v
return nil return nil
} }
// TimestamptzValue implements the [TimestamptzValuer] interface.
func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) { func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) {
return tstz, nil return tstz, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (tstz *Timestamptz) Scan(src any) error { func (tstz *Timestamptz) Scan(src any) error {
if src == nil { if src == nil {
*tstz = Timestamptz{} *tstz = Timestamptz{}
@ -67,7 +63,7 @@ func (tstz *Timestamptz) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (tstz Timestamptz) Value() (driver.Value, error) { func (tstz Timestamptz) Value() (driver.Value, error) {
if !tstz.Valid { if !tstz.Valid {
return nil, nil return nil, nil
@ -79,7 +75,6 @@ func (tstz Timestamptz) Value() (driver.Value, error) {
return tstz.Time, nil return tstz.Time, nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (tstz Timestamptz) MarshalJSON() ([]byte, error) { func (tstz Timestamptz) MarshalJSON() ([]byte, error) {
if !tstz.Valid { if !tstz.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -99,7 +94,6 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) {
return json.Marshal(s) return json.Marshal(s)
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (tstz *Timestamptz) UnmarshalJSON(b []byte) error { func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
var s *string var s *string
err := json.Unmarshal(b, &s) err := json.Unmarshal(b, &s)
@ -231,6 +225,7 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by
} }
func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -3,7 +3,6 @@ package pgtype
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/binary" "encoding/binary"
"encoding/json"
"fmt" "fmt"
"math" "math"
"strconv" "strconv"
@ -25,18 +24,16 @@ type Uint32 struct {
Valid bool Valid bool
} }
// ScanUint32 implements the [Uint32Scanner] interface.
func (n *Uint32) ScanUint32(v Uint32) error { func (n *Uint32) ScanUint32(v Uint32) error {
*n = v *n = v
return nil return nil
} }
// Uint32Value implements the [Uint32Valuer] interface.
func (n Uint32) Uint32Value() (Uint32, error) { func (n Uint32) Uint32Value() (Uint32, error) {
return n, nil return n, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Uint32) Scan(src any) error { func (dst *Uint32) Scan(src any) error {
if src == nil { if src == nil {
*dst = Uint32{} *dst = Uint32{}
@ -70,7 +67,7 @@ func (dst *Uint32) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Uint32) Value() (driver.Value, error) { func (src Uint32) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -78,31 +75,6 @@ func (src Uint32) Value() (driver.Value, error) {
return int64(src.Uint32), nil return int64(src.Uint32), nil
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src Uint32) MarshalJSON() ([]byte, error) {
if !src.Valid {
return []byte("null"), nil
}
return json.Marshal(src.Uint32)
}
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *Uint32) UnmarshalJSON(b []byte) error {
var n *uint32
err := json.Unmarshal(b, &n)
if err != nil {
return err
}
if n == nil {
*dst = Uint32{}
} else {
*dst = Uint32{Uint32: *n, Valid: true}
}
return nil
}
type Uint32Codec struct{} type Uint32Codec struct{}
func (Uint32Codec) FormatSupported(format int16) bool { func (Uint32Codec) FormatSupported(format int16) bool {
@ -225,6 +197,7 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu
} }
func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -24,18 +24,16 @@ type Uint64 struct {
Valid bool Valid bool
} }
// ScanUint64 implements the [Uint64Scanner] interface.
func (n *Uint64) ScanUint64(v Uint64) error { func (n *Uint64) ScanUint64(v Uint64) error {
*n = v *n = v
return nil return nil
} }
// Uint64Value implements the [Uint64Valuer] interface.
func (n Uint64) Uint64Value() (Uint64, error) { func (n Uint64) Uint64Value() (Uint64, error) {
return n, nil return n, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Uint64) Scan(src any) error { func (dst *Uint64) Scan(src any) error {
if src == nil { if src == nil {
*dst = Uint64{} *dst = Uint64{}
@ -65,7 +63,7 @@ func (dst *Uint64) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Uint64) Value() (driver.Value, error) { func (src Uint64) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -196,6 +194,7 @@ func (encodePlanUint64CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu
} }
func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format { switch format {
case BinaryFormatCode: case BinaryFormatCode:
switch target.(type) { switch target.(type) {

View File

@ -20,13 +20,11 @@ type UUID struct {
Valid bool Valid bool
} }
// ScanUUID implements the [UUIDScanner] interface.
func (b *UUID) ScanUUID(v UUID) error { func (b *UUID) ScanUUID(v UUID) error {
*b = v *b = v
return nil return nil
} }
// UUIDValue implements the [UUIDValuer] interface.
func (b UUID) UUIDValue() (UUID, error) { func (b UUID) UUIDValue() (UUID, error) {
return b, nil return b, nil
} }
@ -69,7 +67,7 @@ func encodeUUID(src [16]byte) string {
return string(buf[:]) return string(buf[:])
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *UUID) Scan(src any) error { func (dst *UUID) Scan(src any) error {
if src == nil { if src == nil {
*dst = UUID{} *dst = UUID{}
@ -89,7 +87,7 @@ func (dst *UUID) Scan(src any) error {
return fmt.Errorf("cannot scan %T", src) return fmt.Errorf("cannot scan %T", src)
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src UUID) Value() (driver.Value, error) { func (src UUID) Value() (driver.Value, error) {
if !src.Valid { if !src.Valid {
return nil, nil return nil, nil
@ -106,7 +104,6 @@ func (src UUID) String() string {
return encodeUUID(src.Bytes) return encodeUUID(src.Bytes)
} }
// MarshalJSON implements the [encoding/json.Marshaler] interface.
func (src UUID) MarshalJSON() ([]byte, error) { func (src UUID) MarshalJSON() ([]byte, error) {
if !src.Valid { if !src.Valid {
return []byte("null"), nil return []byte("null"), nil
@ -119,7 +116,6 @@ func (src UUID) MarshalJSON() ([]byte, error) {
return buff.Bytes(), nil return buff.Bytes(), nil
} }
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
func (dst *UUID) UnmarshalJSON(src []byte) error { func (dst *UUID) UnmarshalJSON(src []byte) error {
if bytes.Equal(src, []byte("null")) { if bytes.Equal(src, []byte("null")) {
*dst = UUID{} *dst = UUID{}

View File

@ -8,10 +8,9 @@ import (
type Float8 float64 type Float8 float64
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Float8) SkipUnderlyingTypePlan() {} func (Float8) SkipUnderlyingTypePlan() {}
// ScanFloat64 implements the [pgtype.Float64Scanner] interface. // ScanFloat64 implements the Float64Scanner interface.
func (f *Float8) ScanFloat64(n pgtype.Float8) error { func (f *Float8) ScanFloat64(n pgtype.Float8) error {
if !n.Valid { if !n.Valid {
*f = 0 *f = 0
@ -23,7 +22,6 @@ func (f *Float8) ScanFloat64(n pgtype.Float8) error {
return nil return nil
} }
// Float64Value implements the [pgtype.Float64Valuer] interface.
func (f Float8) Float64Value() (pgtype.Float8, error) { func (f Float8) Float64Value() (pgtype.Float8, error) {
if f == 0 { if f == 0 {
return pgtype.Float8{}, nil return pgtype.Float8{}, nil
@ -31,7 +29,7 @@ func (f Float8) Float64Value() (pgtype.Float8, error) {
return pgtype.Float8{Float64: float64(f), Valid: true}, nil return pgtype.Float8{Float64: float64(f), Valid: true}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (f *Float8) Scan(src any) error { func (f *Float8) Scan(src any) error {
if src == nil { if src == nil {
*f = 0 *f = 0
@ -49,7 +47,7 @@ func (f *Float8) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (f Float8) Value() (driver.Value, error) { func (f Float8) Value() (driver.Value, error) {
if f == 0 { if f == 0 {
return nil, nil return nil, nil

View File

@ -12,36 +12,27 @@ import (
type Int2 int16 type Int2 int16
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Int2) SkipUnderlyingTypePlan() {} func (Int2) SkipUnderlyingTypePlan() {}
// ScanInt64 implements the [pgtype.Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int2) ScanInt64(n pgtype.Int8) error { func (dst *Int2) ScanInt64(n int64, valid bool) error {
if !n.Valid { if !valid {
*dst = 0 *dst = 0
return nil return nil
} }
if n.Int64 < math.MinInt16 { if n < math.MinInt16 {
return fmt.Errorf("%d is less than minimum value for Int2", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int2", n)
} }
if n.Int64 > math.MaxInt16 { if n > math.MaxInt16 {
return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int2", n)
} }
*dst = Int2(n.Int64) *dst = Int2(n)
return nil return nil
} }
// Int64Value implements the [pgtype.Int64Valuer] interface. // Scan implements the database/sql Scanner interface.
func (src Int2) Int64Value() (pgtype.Int8, error) {
if src == 0 {
return pgtype.Int8{}, nil
}
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
}
// Scan implements the [database/sql.Scanner] interface.
func (dst *Int2) Scan(src any) error { func (dst *Int2) Scan(src any) error {
if src == nil { if src == nil {
*dst = 0 *dst = 0
@ -59,7 +50,7 @@ func (dst *Int2) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int2) Value() (driver.Value, error) { func (src Int2) Value() (driver.Value, error) {
if src == 0 { if src == 0 {
return nil, nil return nil, nil
@ -69,36 +60,27 @@ func (src Int2) Value() (driver.Value, error) {
type Int4 int32 type Int4 int32
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Int4) SkipUnderlyingTypePlan() {} func (Int4) SkipUnderlyingTypePlan() {}
// ScanInt64 implements the [pgtype.Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int4) ScanInt64(n pgtype.Int8) error { func (dst *Int4) ScanInt64(n int64, valid bool) error {
if !n.Valid { if !valid {
*dst = 0 *dst = 0
return nil return nil
} }
if n.Int64 < math.MinInt32 { if n < math.MinInt32 {
return fmt.Errorf("%d is less than minimum value for Int4", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int4", n)
} }
if n.Int64 > math.MaxInt32 { if n > math.MaxInt32 {
return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int4", n)
} }
*dst = Int4(n.Int64) *dst = Int4(n)
return nil return nil
} }
// Int64Value implements the [pgtype.Int64Valuer] interface. // Scan implements the database/sql Scanner interface.
func (src Int4) Int64Value() (pgtype.Int8, error) {
if src == 0 {
return pgtype.Int8{}, nil
}
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
}
// Scan implements the [database/sql.Scanner] interface.
func (dst *Int4) Scan(src any) error { func (dst *Int4) Scan(src any) error {
if src == nil { if src == nil {
*dst = 0 *dst = 0
@ -116,7 +98,7 @@ func (dst *Int4) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int4) Value() (driver.Value, error) { func (src Int4) Value() (driver.Value, error) {
if src == 0 { if src == 0 {
return nil, nil return nil, nil
@ -126,36 +108,27 @@ func (src Int4) Value() (driver.Value, error) {
type Int8 int64 type Int8 int64
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Int8) SkipUnderlyingTypePlan() {} func (Int8) SkipUnderlyingTypePlan() {}
// ScanInt64 implements the [pgtype.Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int8) ScanInt64(n pgtype.Int8) error { func (dst *Int8) ScanInt64(n int64, valid bool) error {
if !n.Valid { if !valid {
*dst = 0 *dst = 0
return nil return nil
} }
if n.Int64 < math.MinInt64 { if n < math.MinInt64 {
return fmt.Errorf("%d is less than minimum value for Int8", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int8", n)
} }
if n.Int64 > math.MaxInt64 { if n > math.MaxInt64 {
return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int8", n)
} }
*dst = Int8(n.Int64) *dst = Int8(n)
return nil return nil
} }
// Int64Value implements the [pgtype.Int64Valuer] interface. // Scan implements the database/sql Scanner interface.
func (src Int8) Int64Value() (pgtype.Int8, error) {
if src == 0 {
return pgtype.Int8{}, nil
}
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
}
// Scan implements the [database/sql.Scanner] interface.
func (dst *Int8) Scan(src any) error { func (dst *Int8) Scan(src any) error {
if src == nil { if src == nil {
*dst = 0 *dst = 0
@ -173,7 +146,7 @@ func (dst *Int8) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int8) Value() (driver.Value, error) { func (src Int8) Value() (driver.Value, error) {
if src == 0 { if src == 0 {
return nil, nil return nil, nil

View File

@ -12,36 +12,27 @@ import (
<% pg_bit_size = pg_byte_size * 8 %> <% pg_bit_size = pg_byte_size * 8 %>
type Int<%= pg_byte_size %> int<%= pg_bit_size %> type Int<%= pg_byte_size %> int<%= pg_bit_size %>
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {} func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {}
// ScanInt64 implements the [pgtype.Int64Scanner] interface. // ScanInt64 implements the Int64Scanner interface.
func (dst *Int<%= pg_byte_size %>) ScanInt64(n pgtype.Int8) error { func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error {
if !n.Valid { if !valid {
*dst = 0 *dst = 0
return nil return nil
} }
if n.Int64 < math.MinInt<%= pg_bit_size %> { if n < math.MinInt<%= pg_bit_size %> {
return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n)
} }
if n.Int64 > math.MaxInt<%= pg_bit_size %> { if n > math.MaxInt<%= pg_bit_size %> {
return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64) return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n)
} }
*dst = Int<%= pg_byte_size %>(n.Int64) *dst = Int<%= pg_byte_size %>(n)
return nil return nil
} }
// Int64Value implements the [pgtype.Int64Valuer] interface. // Scan implements the database/sql Scanner interface.
func (src Int<%= pg_byte_size %>) Int64Value() (pgtype.Int8, error) {
if src == 0 {
return pgtype.Int8{}, nil
}
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
}
// Scan implements the [database/sql.Scanner] interface.
func (dst *Int<%= pg_byte_size %>) Scan(src any) error { func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
if src == nil { if src == nil {
*dst = 0 *dst = 0
@ -59,7 +50,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) { func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
if src == 0 { if src == 0 {
return nil, nil return nil, nil

View File

@ -8,10 +8,9 @@ import (
type Text string type Text string
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Text) SkipUnderlyingTypePlan() {} func (Text) SkipUnderlyingTypePlan() {}
// ScanText implements the [pgtype.TextScanner] interface. // ScanText implements the TextScanner interface.
func (dst *Text) ScanText(v pgtype.Text) error { func (dst *Text) ScanText(v pgtype.Text) error {
if !v.Valid { if !v.Valid {
*dst = "" *dst = ""
@ -23,7 +22,7 @@ func (dst *Text) ScanText(v pgtype.Text) error {
return nil return nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (dst *Text) Scan(src any) error { func (dst *Text) Scan(src any) error {
if src == nil { if src == nil {
*dst = "" *dst = ""
@ -41,7 +40,7 @@ func (dst *Text) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (src Text) Value() (driver.Value, error) { func (src Text) Value() (driver.Value, error) {
if src == "" { if src == "" {
return nil, nil return nil, nil

View File

@ -10,10 +10,8 @@ import (
type Timestamp time.Time type Timestamp time.Time
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Timestamp) SkipUnderlyingTypePlan() {} func (Timestamp) SkipUnderlyingTypePlan() {}
// ScanTimestamp implements the [pgtype.TimestampScanner] interface.
func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error { func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error {
if !v.Valid { if !v.Valid {
*ts = Timestamp{} *ts = Timestamp{}
@ -33,7 +31,6 @@ func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error {
} }
} }
// TimestampValue implements the [pgtype.TimestampValuer] interface.
func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) { func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) {
if time.Time(ts).IsZero() { if time.Time(ts).IsZero() {
return pgtype.Timestamp{}, nil return pgtype.Timestamp{}, nil
@ -42,7 +39,7 @@ func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) {
return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (ts *Timestamp) Scan(src any) error { func (ts *Timestamp) Scan(src any) error {
if src == nil { if src == nil {
*ts = Timestamp{} *ts = Timestamp{}
@ -60,7 +57,7 @@ func (ts *Timestamp) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (ts Timestamp) Value() (driver.Value, error) { func (ts Timestamp) Value() (driver.Value, error) {
if time.Time(ts).IsZero() { if time.Time(ts).IsZero() {
return nil, nil return nil, nil

View File

@ -10,10 +10,8 @@ import (
type Timestamptz time.Time type Timestamptz time.Time
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (Timestamptz) SkipUnderlyingTypePlan() {} func (Timestamptz) SkipUnderlyingTypePlan() {}
// ScanTimestamptz implements the [pgtype.TimestamptzScanner] interface.
func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error { func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error {
if !v.Valid { if !v.Valid {
*ts = Timestamptz{} *ts = Timestamptz{}
@ -33,7 +31,6 @@ func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error {
} }
} }
// TimestamptzValue implements the [pgtype.TimestamptzValuer] interface.
func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) { func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) {
if time.Time(ts).IsZero() { if time.Time(ts).IsZero() {
return pgtype.Timestamptz{}, nil return pgtype.Timestamptz{}, nil
@ -42,7 +39,7 @@ func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) {
return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (ts *Timestamptz) Scan(src any) error { func (ts *Timestamptz) Scan(src any) error {
if src == nil { if src == nil {
*ts = Timestamptz{} *ts = Timestamptz{}
@ -60,7 +57,7 @@ func (ts *Timestamptz) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (ts Timestamptz) Value() (driver.Value, error) { func (ts Timestamptz) Value() (driver.Value, error) {
if time.Time(ts).IsZero() { if time.Time(ts).IsZero() {
return nil, nil return nil, nil

View File

@ -8,10 +8,9 @@ import (
type UUID [16]byte type UUID [16]byte
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
func (UUID) SkipUnderlyingTypePlan() {} func (UUID) SkipUnderlyingTypePlan() {}
// ScanUUID implements the [pgtype.UUIDScanner] interface. // ScanUUID implements the UUIDScanner interface.
func (u *UUID) ScanUUID(v pgtype.UUID) error { func (u *UUID) ScanUUID(v pgtype.UUID) error {
if !v.Valid { if !v.Valid {
*u = UUID{} *u = UUID{}
@ -23,7 +22,6 @@ func (u *UUID) ScanUUID(v pgtype.UUID) error {
return nil return nil
} }
// UUIDValue implements the [pgtype.UUIDValuer] interface.
func (u UUID) UUIDValue() (pgtype.UUID, error) { func (u UUID) UUIDValue() (pgtype.UUID, error) {
if u == (UUID{}) { if u == (UUID{}) {
return pgtype.UUID{}, nil return pgtype.UUID{}, nil
@ -31,7 +29,7 @@ func (u UUID) UUIDValue() (pgtype.UUID, error) {
return pgtype.UUID{Bytes: u, Valid: true}, nil return pgtype.UUID{Bytes: u, Valid: true}, nil
} }
// Scan implements the [database/sql.Scanner] interface. // Scan implements the database/sql Scanner interface.
func (u *UUID) Scan(src any) error { func (u *UUID) Scan(src any) error {
if src == nil { if src == nil {
*u = UUID{} *u = UUID{}
@ -49,7 +47,7 @@ func (u *UUID) Scan(src any) error {
return nil return nil
} }
// Value implements the [database/sql/driver.Valuer] interface. // Value implements the database/sql/driver Valuer interface.
func (u UUID) Value() (driver.Value, error) { func (u UUID) Value() (driver.Value, error) {
if u == (UUID{}) { if u == (UUID{}) {
return nil, nil return nil, nil

View File

@ -97,8 +97,7 @@ func testCopyFrom(t *testing.T, ctx context.Context, db interface {
execer execer
queryer queryer
copyFromer copyFromer
}, }) {
) {
_, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) _, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
require.NoError(t, err) require.NoError(t, err)
@ -142,7 +141,6 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName
// Can't test function equality, so just test that they are set or not. // Can't test function equality, so just test that they are set or not.
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName) assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName) assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
assert.Equalf(t, expected.PrepareConn == nil, actual.PrepareConn == nil, "%s - PrepareConn", testName)
assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName) assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName) assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)

View File

@ -2,7 +2,7 @@ package pgxpool
import ( import (
"context" "context"
"errors" "fmt"
"math/rand" "math/rand"
"runtime" "runtime"
"strconv" "strconv"
@ -15,14 +15,12 @@ import (
"github.com/jackc/puddle/v2" "github.com/jackc/puddle/v2"
) )
var ( var defaultMaxConns = int32(4)
defaultMaxConns = int32(4) var defaultMinConns = int32(0)
defaultMinConns = int32(0) var defaultMinIdleConns = int32(0)
defaultMinIdleConns = int32(0) var defaultMaxConnLifetime = time.Hour
defaultMaxConnLifetime = time.Hour var defaultMaxConnIdleTime = time.Minute * 30
defaultMaxConnIdleTime = time.Minute * 30 var defaultHealthCheckPeriod = time.Minute
defaultHealthCheckPeriod = time.Minute
)
type connResource struct { type connResource struct {
conn *pgx.Conn conn *pgx.Conn
@ -86,10 +84,9 @@ type Pool struct {
config *Config config *Config
beforeConnect func(context.Context, *pgx.ConnConfig) error beforeConnect func(context.Context, *pgx.ConnConfig) error
afterConnect func(context.Context, *pgx.Conn) error afterConnect func(context.Context, *pgx.Conn) error
prepareConn func(context.Context, *pgx.Conn) (bool, error) beforeAcquire func(context.Context, *pgx.Conn) bool
afterRelease func(*pgx.Conn) bool afterRelease func(*pgx.Conn) bool
beforeClose func(*pgx.Conn) beforeClose func(*pgx.Conn)
shouldPing func(context.Context, ShouldPingParams) bool
minConns int32 minConns int32
minIdleConns int32 minIdleConns int32
maxConns int32 maxConns int32
@ -107,12 +104,6 @@ type Pool struct {
closeChan chan struct{} closeChan chan struct{}
} }
// ShouldPingParams are the parameters passed to ShouldPing.
type ShouldPingParams struct {
Conn *pgx.Conn
IdleDuration time.Duration
}
// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be // Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
// modified. // modified.
type Config struct { type Config struct {
@ -128,23 +119,8 @@ type Config struct {
// BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the // BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the
// acquisition or false to indicate that the connection should be destroyed and a different connection should be // acquisition or false to indicate that the connection should be destroyed and a different connection should be
// acquired. // acquired.
//
// Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take
// precedence, ignoring BeforeAcquire.
BeforeAcquire func(context.Context, *pgx.Conn) bool BeforeAcquire func(context.Context, *pgx.Conn) bool
// PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection
// is considered valid, otherwise the connection is destroyed. If the function returns a non-nil error, the instigating
// query will fail with the returned error.
//
// Specifically, this means that:
//
// - If it returns true and a nil error, the query proceeds as normal.
// - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error.
// - If it returns false, and an error, the connection will be destroyed, and the query will fail with the returned error.
// - If it returns false and a nil error, the connection will be destroyed, and the instigating query will be retried on a new connection.
PrepareConn func(context.Context, *pgx.Conn) (bool, error)
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to // AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
// return the connection to the pool or false to destroy the connection. // return the connection to the pool or false to destroy the connection.
AfterRelease func(*pgx.Conn) bool AfterRelease func(*pgx.Conn) bool
@ -152,10 +128,6 @@ type Config struct {
// BeforeClose is called right before a connection is closed and removed from the pool. // BeforeClose is called right before a connection is closed and removed from the pool.
BeforeClose func(*pgx.Conn) BeforeClose func(*pgx.Conn)
// ShouldPing is called after a connection is acquired from the pool. If it returns true, the connection is pinged to check for liveness.
// If this func is not set, the default behavior is to ping connections that have been idle for at least 1 second.
ShouldPing func(context.Context, ShouldPingParams) bool
// MaxConnLifetime is the duration since creation after which a connection will be automatically closed. // MaxConnLifetime is the duration since creation after which a connection will be automatically closed.
MaxConnLifetime time.Duration MaxConnLifetime time.Duration
@ -218,18 +190,11 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
panic("config must be created by ParseConfig") panic("config must be created by ParseConfig")
} }
prepareConn := config.PrepareConn
if prepareConn == nil && config.BeforeAcquire != nil {
prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) {
return config.BeforeAcquire(ctx, conn), nil
}
}
p := &Pool{ p := &Pool{
config: config, config: config,
beforeConnect: config.BeforeConnect, beforeConnect: config.BeforeConnect,
afterConnect: config.AfterConnect, afterConnect: config.AfterConnect,
prepareConn: prepareConn, beforeAcquire: config.BeforeAcquire,
afterRelease: config.AfterRelease, afterRelease: config.AfterRelease,
beforeClose: config.BeforeClose, beforeClose: config.BeforeClose,
minConns: config.MinConns, minConns: config.MinConns,
@ -251,14 +216,6 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
p.releaseTracer = t p.releaseTracer = t
} }
if config.ShouldPing != nil {
p.shouldPing = config.ShouldPing
} else {
p.shouldPing = func(ctx context.Context, params ShouldPingParams) bool {
return params.IdleDuration > time.Second
}
}
var err error var err error
p.p, err = puddle.NewPool( p.p, err = puddle.NewPool(
&puddle.Config[*connResource]{ &puddle.Config[*connResource]{
@ -364,10 +321,10 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conns") delete(connConfig.Config.RuntimeParams, "pool_max_conns")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conns", err) return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err)
} }
if n < 1 { if n < 1 {
return nil, pgconn.NewParseConfigError(connString, "pool_max_conns too small", err) return nil, fmt.Errorf("pool_max_conns too small: %d", n)
} }
config.MaxConns = int32(n) config.MaxConns = int32(n)
} else { } else {
@ -381,7 +338,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_min_conns") delete(connConfig.Config.RuntimeParams, "pool_min_conns")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_conns", err) return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err)
} }
config.MinConns = int32(n) config.MinConns = int32(n)
} else { } else {
@ -392,7 +349,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns") delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns")
n, err := strconv.ParseInt(s, 10, 32) n, err := strconv.ParseInt(s, 10, 32)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_idle_conns", err) return nil, fmt.Errorf("cannot parse pool_min_idle_conns: %w", err)
} }
config.MinIdleConns = int32(n) config.MinIdleConns = int32(n)
} else { } else {
@ -403,7 +360,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime") delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime", err) return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err)
} }
config.MaxConnLifetime = d config.MaxConnLifetime = d
} else { } else {
@ -414,7 +371,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time") delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_idle_time", err) return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err)
} }
config.MaxConnIdleTime = d config.MaxConnIdleTime = d
} else { } else {
@ -425,7 +382,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_health_check_period") delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_health_check_period", err) return nil, fmt.Errorf("invalid pool_health_check_period: %w", err)
} }
config.HealthCheckPeriod = d config.HealthCheckPeriod = d
} else { } else {
@ -436,7 +393,7 @@ func ParseConfig(connString string) (*Config, error) {
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter") delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter")
d, err := time.ParseDuration(s) d, err := time.ParseDuration(s)
if err != nil { if err != nil {
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime_jitter", err) return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err)
} }
config.MaxConnLifetimeJitter = d config.MaxConnLifetimeJitter = d
} }
@ -588,10 +545,7 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
}() }()
} }
// Try to acquire from the connection pool up to maxConns + 1 times, so that for {
// any that fatal errors would empty the pool and still at least try 1 fresh
// connection.
for range p.maxConns + 1 {
res, err := p.p.Acquire(ctx) res, err := p.p.Acquire(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -599,8 +553,7 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
cr := res.Value() cr := res.Value()
shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()} if res.IdleDuration() > time.Second {
if p.shouldPing(ctx, shouldPingParams) {
err := cr.conn.Ping(ctx) err := cr.conn.Ping(ctx)
if err != nil { if err != nil {
res.Destroy() res.Destroy()
@ -608,25 +561,12 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
} }
} }
if p.prepareConn != nil { if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
ok, err := p.prepareConn(ctx, cr.conn) return cr.getConn(p, res), nil
if !ok {
res.Destroy()
}
if err != nil {
if ok {
res.Release()
}
return nil, err
}
if !ok {
continue
}
} }
return cr.getConn(p, res), nil res.Destroy()
} }
return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook")
} }
// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the // AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the
@ -649,14 +589,11 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
conns := make([]*Conn, 0, len(resources)) conns := make([]*Conn, 0, len(resources))
for _, res := range resources { for _, res := range resources {
cr := res.Value() cr := res.Value()
if p.prepareConn != nil { if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
ok, err := p.prepareConn(ctx, cr.conn) conns = append(conns, cr.getConn(p, res))
if !ok || err != nil { } else {
res.Destroy() res.Destroy()
continue
}
} }
conns = append(conns, cr.getConn(p, res))
} }
return conns return conns

View File

@ -204,47 +204,6 @@ func TestPoolAcquireChecksIdleConns(t *testing.T) {
require.NotContains(t, pids, cPID) require.NotContains(t, pids, cPID)
} }
func TestPoolAcquireChecksIdleConnsWithShouldPing(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer controllerConn.Close(ctx)
config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
// Replace the default ShouldPing func
var shouldPingLastCalledWith *pgxpool.ShouldPingParams
config.ShouldPing = func(ctx context.Context, params pgxpool.ShouldPingParams) bool {
shouldPingLastCalledWith = &params
return false
}
pool, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)
defer pool.Close()
c, err := pool.Acquire(ctx)
require.NoError(t, err)
c.Release()
time.Sleep(time.Millisecond * 200)
c, err = pool.Acquire(ctx)
require.NoError(t, err)
conn := c.Conn()
require.NotNil(t, shouldPingLastCalledWith)
assert.Equal(t, conn, shouldPingLastCalledWith.Conn)
assert.InDelta(t, time.Millisecond*200, shouldPingLastCalledWith.IdleDuration, float64(time.Millisecond*100))
c.Release()
}
func TestPoolAcquireFunc(t *testing.T) { func TestPoolAcquireFunc(t *testing.T) {
t.Parallel() t.Parallel()
@ -371,64 +330,6 @@ func TestPoolBeforeAcquire(t *testing.T) {
assert.EqualValues(t, 12, acquireAttempts) assert.EqualValues(t, 12, acquireAttempts)
} }
func TestPoolPrepareConn(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
acquireAttempts := 0
config.PrepareConn = func(context.Context, *pgx.Conn) (bool, error) {
acquireAttempts++
var err error
if acquireAttempts%3 == 0 {
err = errors.New("PrepareConn error")
}
return acquireAttempts%2 == 0, err
}
db, err := pgxpool.NewWithConfig(ctx, config)
require.NoError(t, err)
t.Cleanup(db.Close)
var errorCount int
conns := make([]*pgxpool.Conn, 0, 4)
for {
conn, err := db.Acquire(ctx)
if err != nil {
errorCount++
continue
}
conns = append(conns, conn)
if len(conns) == 4 {
break
}
}
const wantErrorCount = 3
assert.Equal(t, wantErrorCount, errorCount, "Acquire() should have failed %d times", wantErrorCount)
for _, c := range conns {
c.Release()
}
waitForReleaseToComplete()
assert.EqualValues(t, len(conns)*2+wantErrorCount-1, acquireAttempts)
conns = db.AcquireAllIdle(ctx)
assert.Len(t, conns, 1)
for _, c := range conns {
c.Release()
}
waitForReleaseToComplete()
assert.EqualValues(t, 14, acquireAttempts)
}
func TestPoolAfterRelease(t *testing.T) { func TestPoolAfterRelease(t *testing.T) {
t.Parallel() t.Parallel()
@ -776,6 +677,7 @@ func TestPoolQuery(t *testing.T) {
stats = pool.Stat() stats = pool.Stat()
assert.EqualValues(t, 0, stats.AcquiredConns()) assert.EqualValues(t, 0, stats.AcquiredConns())
assert.EqualValues(t, 1, stats.TotalConns()) assert.EqualValues(t, 1, stats.TotalConns())
} }
func TestPoolQueryRow(t *testing.T) { func TestPoolQueryRow(t *testing.T) {
@ -1180,9 +1082,9 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
acquireAttempts := int64(0) acquireAttempts := int64(0)
connectAttempts := int64(0) connectAttempts := int64(0)
config.PrepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) { config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
atomic.AddInt64(&acquireAttempts, 1) atomic.AddInt64(&acquireAttempts, 1)
return true, nil return true
} }
config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error { config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error {
atomic.AddInt64(&connectAttempts, 1) atomic.AddInt64(&connectAttempts, 1)
@ -1203,6 +1105,7 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
} }
t.Fatal("did not reach min pool size") t.Fatal("did not reach min pool size")
} }
func TestPoolSendBatchBatchCloseTwice(t *testing.T) { func TestPoolSendBatchBatchCloseTwice(t *testing.T) {

View File

@ -568,6 +568,7 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) {
ensureConnValid(t, conn) ensureConnValid(t, conn)
}() }()
} }
} }
func TestQueryEncodeError(t *testing.T) { func TestQueryEncodeError(t *testing.T) {
@ -2207,6 +2208,7 @@ insert into products (name, price) values
} }
rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12) rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12)
// It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare // It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare
// cases it may be useful to detect the error as early as possible. // cases it may be useful to detect the error as early as possible.
if err != nil { if err != nil {

23
rows.go
View File

@ -41,19 +41,22 @@ type Rows interface {
// when there was an error executing the query. // when there was an error executing the query.
FieldDescriptions() []pgconn.FieldDescription FieldDescriptions() []pgconn.FieldDescription
// Next prepares the next row for reading. It returns true if there is another row and false if no more rows are // Next prepares the next row for reading. It returns true if there is another
// available or a fatal error has occurred. It automatically closes rows upon returning false (whether due to all rows // row and false if no more rows are available or a fatal error has occurred.
// having been read or due to an error). // It automatically closes rows when all rows are read.
// //
// Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended // Callers should check rows.Err() after rows.Next() returns false to detect
// prematurely due to an error. See Conn.Query for details. // whether result-set reading ended prematurely due to an error. See
// Conn.Query for details.
// //
// For simpler error handling, consider using the higher-level pgx v5 CollectRows() and ForEachRow() helpers instead. // For simpler error handling, consider using the higher-level pgx v5
// CollectRows() and ForEachRow() helpers instead.
Next() bool Next() bool
// Scan reads the values from the current row into dest values positionally. dest can include pointers to core types, // Scan reads the values from the current row into dest values positionally.
// values implementing the Scanner interface, and nil. nil will skip the value entirely. It is an error to call Scan // dest can include pointers to core types, values implementing the Scanner
// without first calling Next() and checking that it returned true. Rows is automatically closed upon error. // interface, and nil. nil will skip the value entirely. It is an error to
// call Scan without first calling Next() and checking that it returned true.
Scan(dest ...any) error Scan(dest ...any) error
// Values returns the decoded row values. As with Scan(), it is an error to // Values returns the decoded row values. As with Scan(), it is an error to
@ -560,7 +563,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
return nil return nil
} }
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number of public fields as row // RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be // has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
// ignored. // ignored.
func RowToStructByPos[T any](row CollectableRow) (T, error) { func RowToStructByPos[T any](row CollectableRow) (T, error) {

View File

@ -471,8 +471,7 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
args := make([]any, len(argsV)) args := namedValueToInterface(argsV)
convertNamedArguments(args, argsV)
commandTag, err := c.conn.Exec(ctx, query, args...) commandTag, err := c.conn.Exec(ctx, query, args...)
// if we got a network error before we had a chance to send the query, retry // if we got a network error before we had a chance to send the query, retry
@ -489,9 +488,8 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
args := make([]any, 1+len(argsV)) args := []any{databaseSQLResultFormats}
args[0] = databaseSQLResultFormats args = append(args, namedValueToInterface(argsV)...)
convertNamedArguments(args[1:], argsV)
rows, err := c.conn.Query(ctx, query, args...) rows, err := c.conn.Query(ctx, query, args...)
if err != nil { if err != nil {
@ -850,14 +848,28 @@ func (r *Rows) Next(dest []driver.Value) error {
return nil return nil
} }
func convertNamedArguments(args []any, argsV []driver.NamedValue) { func valueToInterface(argsV []driver.Value) []any {
for i, v := range argsV { args := make([]any, 0, len(argsV))
if v.Value != nil { for _, v := range argsV {
args[i] = v.Value.(any) if v != nil {
args = append(args, v.(any))
} else { } else {
args[i] = nil args = append(args, nil)
} }
} }
return args
}
func namedValueToInterface(argsV []driver.NamedValue) []any {
args := make([]any, 0, len(argsV))
for _, v := range argsV {
if v.Value != nil {
args = append(args, v.Value.(any))
} else {
args = append(args, nil)
}
}
return args
} }
type wrapTx struct { type wrapTx struct {

View File

@ -161,6 +161,7 @@ func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password
} }
return nil return nil
} }
func writeCertificate(path string, certBytes []byte) error { func writeCertificate(path string, certBytes []byte) error {

View File

@ -104,7 +104,7 @@ func logQueryArgs(args []any) []any {
} }
case string: case string:
if len(v) > 64 { if len(v) > 64 {
l := 0 var l = 0
for w := 0; l < 64; l += w { for w := 0; l < 64; l += w {
_, w = utf8.DecodeRuneInString(v[l:]) _, w = utf8.DecodeRuneInString(v[l:])
} }

View File

@ -362,6 +362,7 @@ func TestLogBatchStatementsOnExec(t *testing.T) {
assert.Equal(t, "BatchQuery", logger.logs[1].msg) assert.Equal(t, "BatchQuery", logger.logs[1].msg)
assert.Equal(t, "drop table foo", logger.logs[1].data["sql"]) assert.Equal(t, "drop table foo", logger.logs[1].data["sql"])
assert.Equal(t, "BatchClose", logger.logs[2].msg) assert.Equal(t, "BatchClose", logger.logs[2].msg)
}) })
} }

View File

@ -117,6 +117,7 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) {
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
testJSONStruct(t, conn, typename) testJSONStruct(t, conn, typename)
} }
} }
func testJSONString(t testing.TB, conn *pgx.Conn, typename string) { func testJSONString(t testing.TB, conn *pgx.Conn, typename string) {
@ -596,9 +597,7 @@ func TestArrayDecoding(t *testing.T) {
assert func(testing.TB, any, any) assert func(testing.TB, any, any)
}{ }{
{ {
"select $1::bool[]", "select $1::bool[]", []bool{true, false, true}, &[]bool{},
[]bool{true, false, true},
&[]bool{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]bool))) { if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
t.Errorf("failed to encode bool[]") t.Errorf("failed to encode bool[]")
@ -606,9 +605,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::smallint[]", "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
[]int16{2, 4, 484, 32767},
&[]int16{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]int16))) { if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
t.Errorf("failed to encode smallint[]") t.Errorf("failed to encode smallint[]")
@ -616,9 +613,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::smallint[]", "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
[]uint16{2, 4, 484, 32767},
&[]uint16{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
t.Errorf("failed to encode smallint[]") t.Errorf("failed to encode smallint[]")
@ -626,9 +621,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::int[]", "select $1::int[]", []int32{2, 4, 484}, &[]int32{},
[]int32{2, 4, 484},
&[]int32{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]int32))) { if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
t.Errorf("failed to encode int[]") t.Errorf("failed to encode int[]")
@ -636,9 +629,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::int[]", "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
[]uint32{2, 4, 484, 2147483647},
&[]uint32{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
t.Errorf("failed to encode int[]") t.Errorf("failed to encode int[]")
@ -646,9 +637,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::bigint[]", "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
[]int64{2, 4, 484, 9223372036854775807},
&[]int64{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]int64))) { if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
t.Errorf("failed to encode bigint[]") t.Errorf("failed to encode bigint[]")
@ -656,9 +645,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::bigint[]", "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
[]uint64{2, 4, 484, 9223372036854775807},
&[]uint64{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
t.Errorf("failed to encode bigint[]") t.Errorf("failed to encode bigint[]")
@ -666,9 +653,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::text[]", "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
[]string{"it's", "over", "9000!"},
&[]string{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
if !reflect.DeepEqual(query, *(scan.(*[]string))) { if !reflect.DeepEqual(query, *(scan.(*[]string))) {
t.Errorf("failed to encode text[]") t.Errorf("failed to encode text[]")
@ -676,9 +661,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::timestamptz[]", "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
[]time.Time{time.Unix(323232, 0), time.Unix(3239949334, 0o0)},
&[]time.Time{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
queryTimeSlice := query.([]time.Time) queryTimeSlice := query.([]time.Time)
scanTimeSlice := *(scan.(*[]time.Time)) scanTimeSlice := *(scan.(*[]time.Time))
@ -689,9 +672,7 @@ func TestArrayDecoding(t *testing.T) {
}, },
}, },
{ {
"select $1::bytea[]", "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{},
[][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}},
&[][]byte{},
func(t testing.TB, query, scan any) { func(t testing.TB, query, scan any) {
queryBytesSliceSlice := query.([][]byte) queryBytesSliceSlice := query.([][]byte)
scanBytesSliceSlice := *(scan.(*[][]byte)) scanBytesSliceSlice := *(scan.(*[][]byte))