mirror of
https://github.com/jackc/pgx.git
synced 2025-09-04 19:37:10 +00:00
Compare commits
46 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
95fc31294f | ||
|
5534fa9a02 | ||
|
a295d68811 | ||
|
03f32c06bd | ||
|
82fbe49fec | ||
|
594d9d65dc | ||
|
5a18241971 | ||
|
cc34da5884 | ||
|
dd81f81e2f | ||
|
839acbaf18 | ||
|
d1a00a6cd4 | ||
|
81c0db4f49 | ||
|
1516fb8125 | ||
|
b1ef6d90c0 | ||
|
59c73af6bb | ||
|
248afe61b1 | ||
|
88500ac027 | ||
|
f1c8fcd5c2 | ||
|
fc289cbbe8 | ||
|
5cb495fb94 | ||
|
562761a083 | ||
|
fce1a04dbf | ||
|
25cba15299 | ||
|
cec5ebac5b | ||
|
f43091fc80 | ||
|
a11da9a629 | ||
|
2f77a63ce2 | ||
|
39b85ce8d1 | ||
|
e04a6de072 | ||
|
c39a0608a3 | ||
|
f661c47dc8 | ||
|
e68ff102de | ||
|
30ff631878 | ||
|
69934dcd95 | ||
|
ecc9203ef4 | ||
|
33163eefca | ||
|
d2ee7464e8 | ||
|
1320d13f8a | ||
|
f118bb6033 | ||
|
b0572f79e6 | ||
|
4015a0c123 | ||
|
48d27a9fff | ||
|
fc334e4c75 | ||
|
3f5509fe98 | ||
|
de806a11e7 | ||
|
ce13266e90 |
21
.golangci.yml
Normal file
21
.golangci.yml
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
# See for configurations: https://golangci-lint.run/usage/configuration/
|
||||||
|
version: 2
|
||||||
|
|
||||||
|
# See: https://golangci-lint.run/usage/formatters/
|
||||||
|
formatters:
|
||||||
|
default: none
|
||||||
|
enable:
|
||||||
|
- gofmt # https://pkg.go.dev/cmd/gofmt
|
||||||
|
- gofumpt # https://github.com/mvdan/gofumpt
|
||||||
|
|
||||||
|
settings:
|
||||||
|
gofmt:
|
||||||
|
simplify: true # Simplify code: gofmt with `-s` option.
|
||||||
|
|
||||||
|
gofumpt:
|
||||||
|
# Module path which contains the source code being formatted.
|
||||||
|
# Default: ""
|
||||||
|
module-path: github.com/jackc/pgx/v5 # Should match with module in go.mod
|
||||||
|
# Choose whether to use the extra rules.
|
||||||
|
# Default: false
|
||||||
|
extra-rules: true
|
@ -127,6 +127,7 @@ pgerrcode contains constants for the PostgreSQL error codes.
|
|||||||
## Adapters for 3rd Party Tracers
|
## Adapters for 3rd Party Tracers
|
||||||
|
|
||||||
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
|
* [github.com/jackhopner/pgx-xray-tracer](https://github.com/jackhopner/pgx-xray-tracer)
|
||||||
|
* [github.com/exaring/otelpgx](https://github.com/exaring/otelpgx)
|
||||||
|
|
||||||
## Adapters for 3rd Party Loggers
|
## Adapters for 3rd Party Loggers
|
||||||
|
|
||||||
@ -184,3 +185,7 @@ Simple Golang implementation for transactional outbox pattern for PostgreSQL usi
|
|||||||
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
|
### [https://github.com/Arlandaren/pgxWrappy](https://github.com/Arlandaren/pgxWrappy)
|
||||||
|
|
||||||
Simplifies working with the pgx library, providing convenient scanning of nested structures.
|
Simplifies working with the pgx library, providing convenient scanning of nested structures.
|
||||||
|
|
||||||
|
## [https://github.com/KoNekoD/pgx-colon-query-rewriter](https://github.com/KoNekoD/pgx-colon-query-rewriter)
|
||||||
|
|
||||||
|
Implementation of the pgx query rewriter to use ':' instead of '@' in named query parameters.
|
||||||
|
32
batch.go
32
batch.go
@ -43,6 +43,10 @@ func (qq *QueuedQuery) QueryRow(fn func(row Row) error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Exec sets fn to be called when the response to qq is received.
|
// Exec sets fn to be called when the response to qq is received.
|
||||||
|
//
|
||||||
|
// Note: for simple batch insert uses where it is not required to handle
|
||||||
|
// each potential error individually, it's sufficient to not set any callbacks,
|
||||||
|
// and just handle the return value of BatchResults.Close.
|
||||||
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
func (qq *QueuedQuery) Exec(fn func(ct pgconn.CommandTag) error) {
|
||||||
qq.Fn = func(br BatchResults) error {
|
qq.Fn = func(br BatchResults) error {
|
||||||
ct, err := br.Exec()
|
ct, err := br.Exec()
|
||||||
@ -83,7 +87,7 @@ func (b *Batch) Len() int {
|
|||||||
|
|
||||||
type BatchResults interface {
|
type BatchResults interface {
|
||||||
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
|
// Exec reads the results from the next query in the batch as if the query has been sent with Conn.Exec. Prefer
|
||||||
// calling Exec on the QueuedQuery.
|
// calling Exec on the QueuedQuery, or just calling Close.
|
||||||
Exec() (pgconn.CommandTag, error)
|
Exec() (pgconn.CommandTag, error)
|
||||||
|
|
||||||
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
|
// Query reads the results from the next query in the batch as if the query has been sent with Conn.Query. Prefer
|
||||||
@ -98,6 +102,9 @@ type BatchResults interface {
|
|||||||
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
|
// QueuedQuery.Query, QueuedQuery.QueryRow, or QueuedQuery.Exec will be called. If a callback function returns an
|
||||||
// error or the batch encounters an error subsequent callback functions will not be called.
|
// error or the batch encounters an error subsequent callback functions will not be called.
|
||||||
//
|
//
|
||||||
|
// For simple batch inserts inside a transaction or similar queries, it's sufficient to not set any callbacks,
|
||||||
|
// and just handle the return value of Close.
|
||||||
|
//
|
||||||
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
|
// Close must be called before the underlying connection can be used again. Any error that occurred during a batch
|
||||||
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
|
// operation may have made it impossible to resyncronize the connection with the server. In this case the underlying
|
||||||
// connection will have been closed.
|
// connection will have been closed.
|
||||||
@ -207,7 +214,6 @@ func (br *batchResults) Query() (Rows, error) {
|
|||||||
func (br *batchResults) QueryRow() Row {
|
func (br *batchResults) QueryRow() Row {
|
||||||
rows, _ := br.Query()
|
rows, _ := br.Query()
|
||||||
return (*connRow)(rows.(*baseRows))
|
return (*connRow)(rows.(*baseRows))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||||
@ -220,6 +226,8 @@ func (br *batchResults) Close() error {
|
|||||||
}
|
}
|
||||||
br.endTraced = true
|
br.endTraced = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if br.err != nil {
|
if br.err != nil {
|
||||||
@ -378,7 +386,6 @@ func (br *pipelineBatchResults) Query() (Rows, error) {
|
|||||||
func (br *pipelineBatchResults) QueryRow() Row {
|
func (br *pipelineBatchResults) QueryRow() Row {
|
||||||
rows, _ := br.Query()
|
rows, _ := br.Query()
|
||||||
return (*connRow)(rows.(*baseRows))
|
return (*connRow)(rows.(*baseRows))
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
// Close closes the batch operation. Any error that occurred during a batch operation may have made it impossible to
|
||||||
@ -391,6 +398,8 @@ func (br *pipelineBatchResults) Close() error {
|
|||||||
}
|
}
|
||||||
br.endTraced = true
|
br.endTraced = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
invalidateCachesOnBatchResultsError(br.conn, br.b, br.err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
|
if br.err == nil && br.lastRows != nil && br.lastRows.err != nil {
|
||||||
@ -441,3 +450,20 @@ func (br *pipelineBatchResults) nextQueryAndArgs() (query string, args []any, er
|
|||||||
br.qqIdx++
|
br.qqIdx++
|
||||||
return bi.SQL, bi.Arguments, nil
|
return bi.SQL, bi.Arguments, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// invalidates statement and description caches on batch results error
|
||||||
|
func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) {
|
||||||
|
if err != nil && conn != nil && b != nil {
|
||||||
|
if sc := conn.statementCache; sc != nil {
|
||||||
|
for _, bi := range b.QueuedQueries {
|
||||||
|
sc.Invalidate(bi.SQL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if sc := conn.descriptionCache; sc != nil {
|
||||||
|
for _, bi := range b.QueuedQueries {
|
||||||
|
sc.Invalidate(bi.SQL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -488,7 +488,6 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select n from generate_series(0,5) n")
|
batch.Queue("select n from generate_series(0,5) n")
|
||||||
batch.Queue("select n from generate_series(0,5) n")
|
batch.Queue("select n from generate_series(0,5) n")
|
||||||
@ -539,7 +538,6 @@ func TestConnSendBatchCloseRowsPartiallyRead(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -550,7 +548,6 @@ func TestConnSendBatchQueryError(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
|
batch.Queue("select n from generate_series(0,5) n where 100/(5-n) > 0")
|
||||||
batch.Queue("select n from generate_series(0,5) n")
|
batch.Queue("select n from generate_series(0,5) n")
|
||||||
@ -580,7 +577,6 @@ func TestConnSendBatchQueryError(t *testing.T) {
|
|||||||
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "22012") {
|
||||||
t.Errorf("br.Close() => %v, want error code %v", err, 22012)
|
t.Errorf("br.Close() => %v, want error code %v", err, 22012)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -591,7 +587,6 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
batch := &pgx.Batch{}
|
batch := &pgx.Batch{}
|
||||||
batch.Queue("select 1 1")
|
batch.Queue("select 1 1")
|
||||||
|
|
||||||
@ -607,7 +602,6 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("Expected error")
|
t.Error("Expected error")
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -618,7 +612,6 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
sql := `create temporary table ledger(
|
sql := `create temporary table ledger(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null,
|
description varchar not null,
|
||||||
@ -647,7 +640,6 @@ func TestConnSendBatchQueryRowInsert(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
br.Close()
|
br.Close()
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -658,7 +650,6 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
sql := `create temporary table ledger(
|
sql := `create temporary table ledger(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null,
|
description varchar not null,
|
||||||
@ -687,7 +678,6 @@ func TestConnSendBatchQueryPartialReadInsert(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
br.Close()
|
br.Close()
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -698,7 +688,6 @@ func TestTxSendBatch(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
sql := `create temporary table ledger1(
|
sql := `create temporary table ledger1(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null
|
description varchar not null
|
||||||
@ -757,7 +746,6 @@ func TestTxSendBatch(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -768,7 +756,6 @@ func TestTxSendBatchRollback(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
sql := `create temporary table ledger1(
|
sql := `create temporary table ledger1(
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
description varchar not null
|
description varchar not null
|
||||||
@ -795,7 +782,6 @@ func TestTxSendBatchRollback(t *testing.T) {
|
|||||||
if count != 0 {
|
if count != 0 {
|
||||||
t.Errorf("count => %v, want %v", count, 0)
|
t.Errorf("count => %v, want %v", count, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -855,7 +841,6 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
pgxtest.SkipCockroachDB(t, conn, "Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
||||||
|
|
||||||
mustExec(t, conn, `create temporary table t (
|
mustExec(t, conn, `create temporary table t (
|
||||||
@ -894,7 +879,6 @@ func TestConnBeginBatchDeferredError(t *testing.T) {
|
|||||||
if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
|
if err, ok := err.(*pgconn.PgError); !ok || err.Code != "23505" {
|
||||||
t.Fatalf("expected error 23505, got %v", err)
|
t.Fatalf("expected error 23505, got %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -516,7 +516,6 @@ func multiInsert(conn *pgx.Conn, tableName string, columnNames []string, rowSrc
|
|||||||
}
|
}
|
||||||
|
|
||||||
return rowCount, nil
|
return rowCount, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
|
func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
|
||||||
@ -535,7 +534,8 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
|
|||||||
src := newBenchmarkWriteTableCopyFromSrc(n)
|
src := newBenchmarkWriteTableCopyFromSrc(n)
|
||||||
|
|
||||||
_, err := multiInsert(conn, "t",
|
_, err := multiInsert(conn, "t",
|
||||||
[]string{"varchar_1",
|
[]string{
|
||||||
|
"varchar_1",
|
||||||
"varchar_2",
|
"varchar_2",
|
||||||
"varchar_null_1",
|
"varchar_null_1",
|
||||||
"date_1",
|
"date_1",
|
||||||
@ -547,7 +547,8 @@ func benchmarkWriteNRowsViaMultiInsert(b *testing.B, n int) {
|
|||||||
"tstz_2",
|
"tstz_2",
|
||||||
"bool_1",
|
"bool_1",
|
||||||
"bool_2",
|
"bool_2",
|
||||||
"bool_3"},
|
"bool_3",
|
||||||
|
},
|
||||||
src)
|
src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
@ -568,7 +569,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
|
|||||||
|
|
||||||
_, err := conn.CopyFrom(context.Background(),
|
_, err := conn.CopyFrom(context.Background(),
|
||||||
pgx.Identifier{"t"},
|
pgx.Identifier{"t"},
|
||||||
[]string{"varchar_1",
|
[]string{
|
||||||
|
"varchar_1",
|
||||||
"varchar_2",
|
"varchar_2",
|
||||||
"varchar_null_1",
|
"varchar_null_1",
|
||||||
"date_1",
|
"date_1",
|
||||||
@ -580,7 +582,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
|
|||||||
"tstz_2",
|
"tstz_2",
|
||||||
"bool_1",
|
"bool_1",
|
||||||
"bool_2",
|
"bool_2",
|
||||||
"bool_3"},
|
"bool_3",
|
||||||
|
},
|
||||||
src)
|
src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
@ -611,6 +614,7 @@ func BenchmarkWrite5RowsViaInsert(b *testing.B) {
|
|||||||
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
|
func BenchmarkWrite5RowsViaMultiInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaMultiInsert(b, 5)
|
benchmarkWriteNRowsViaMultiInsert(b, 5)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) {
|
func BenchmarkWrite5RowsViaBatchInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaBatchInsert(b, 5)
|
benchmarkWriteNRowsViaBatchInsert(b, 5)
|
||||||
}
|
}
|
||||||
@ -626,6 +630,7 @@ func BenchmarkWrite10RowsViaInsert(b *testing.B) {
|
|||||||
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
|
func BenchmarkWrite10RowsViaMultiInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaMultiInsert(b, 10)
|
benchmarkWriteNRowsViaMultiInsert(b, 10)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) {
|
func BenchmarkWrite10RowsViaBatchInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaBatchInsert(b, 10)
|
benchmarkWriteNRowsViaBatchInsert(b, 10)
|
||||||
}
|
}
|
||||||
@ -641,6 +646,7 @@ func BenchmarkWrite100RowsViaInsert(b *testing.B) {
|
|||||||
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
|
func BenchmarkWrite100RowsViaMultiInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaMultiInsert(b, 100)
|
benchmarkWriteNRowsViaMultiInsert(b, 100)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) {
|
func BenchmarkWrite100RowsViaBatchInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaBatchInsert(b, 100)
|
benchmarkWriteNRowsViaBatchInsert(b, 100)
|
||||||
}
|
}
|
||||||
@ -672,6 +678,7 @@ func BenchmarkWrite10000RowsViaInsert(b *testing.B) {
|
|||||||
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
|
func BenchmarkWrite10000RowsViaMultiInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaMultiInsert(b, 10000)
|
benchmarkWriteNRowsViaMultiInsert(b, 10000)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) {
|
func BenchmarkWrite10000RowsViaBatchInsert(b *testing.B) {
|
||||||
benchmarkWriteNRowsViaBatchInsert(b, 10000)
|
benchmarkWriteNRowsViaBatchInsert(b, 10000)
|
||||||
}
|
}
|
||||||
@ -1043,7 +1050,6 @@ func BenchmarkSelectRowsScanDecoder(b *testing.B) {
|
|||||||
}
|
}
|
||||||
for _, format := range formats {
|
for _, format := range formats {
|
||||||
b.Run(format.name, func(b *testing.B) {
|
b.Run(format.name, func(b *testing.B) {
|
||||||
|
|
||||||
br := &BenchRowDecoder{}
|
br := &BenchRowDecoder{}
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
rows, err := conn.Query(
|
rows, err := conn.Query(
|
||||||
|
6
conn.go
6
conn.go
@ -172,7 +172,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||||||
delete(config.RuntimeParams, "statement_cache_capacity")
|
delete(config.RuntimeParams, "statement_cache_capacity")
|
||||||
n, err := strconv.ParseInt(s, 10, 32)
|
n, err := strconv.ParseInt(s, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse statement_cache_capacity", err)
|
||||||
}
|
}
|
||||||
statementCacheCapacity = int(n)
|
statementCacheCapacity = int(n)
|
||||||
}
|
}
|
||||||
@ -182,7 +182,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||||||
delete(config.RuntimeParams, "description_cache_capacity")
|
delete(config.RuntimeParams, "description_cache_capacity")
|
||||||
n, err := strconv.ParseInt(s, 10, 32)
|
n, err := strconv.ParseInt(s, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse description_cache_capacity", err)
|
||||||
}
|
}
|
||||||
descriptionCacheCapacity = int(n)
|
descriptionCacheCapacity = int(n)
|
||||||
}
|
}
|
||||||
@ -202,7 +202,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
|
|||||||
case "simple_protocol":
|
case "simple_protocol":
|
||||||
defaultQueryExecMode = QueryExecModeSimpleProtocol
|
defaultQueryExecMode = QueryExecModeSimpleProtocol
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("invalid default_query_exec_mode: %s", s)
|
return nil, pgconn.NewParseConfigError(connString, "invalid default_query_exec_mode", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
176
conn_test.go
176
conn_test.go
@ -412,7 +412,6 @@ func TestExecPerQuerySimpleProtocol(t *testing.T) {
|
|||||||
if commandTag.String() != "INSERT 0 1" {
|
if commandTag.String() != "INSERT 0 1" {
|
||||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepare(t *testing.T) {
|
func TestPrepare(t *testing.T) {
|
||||||
@ -1089,7 +1088,7 @@ func TestLoadRangeType(t *testing.T) {
|
|||||||
conn.TypeMap().RegisterType(newRangeType)
|
conn.TypeMap().RegisterType(newRangeType)
|
||||||
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
|
conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange")
|
||||||
|
|
||||||
var inputRangeType = pgtype.Range[float64]{
|
inputRangeType := pgtype.Range[float64]{
|
||||||
Lower: 1.0,
|
Lower: 1.0,
|
||||||
Upper: 2.0,
|
Upper: 2.0,
|
||||||
LowerType: pgtype.Inclusive,
|
LowerType: pgtype.Inclusive,
|
||||||
@ -1129,7 +1128,7 @@ func TestLoadMultiRangeType(t *testing.T) {
|
|||||||
conn.TypeMap().RegisterType(newMultiRangeType)
|
conn.TypeMap().RegisterType(newMultiRangeType)
|
||||||
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")
|
conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange")
|
||||||
|
|
||||||
var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{
|
inputMultiRangeType := pgtype.Multirange[pgtype.Range[float64]]{
|
||||||
{
|
{
|
||||||
Lower: 1.0,
|
Lower: 1.0,
|
||||||
Upper: 2.0,
|
Upper: 2.0,
|
||||||
@ -1293,6 +1292,177 @@ func TestStmtCacheInvalidationTx(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStmtCacheInvalidationConnWithBatch(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||||
|
t.Skip("Test fails due to different CRDB behavior")
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a table and fill it with some data
|
||||||
|
_, err := conn.Exec(ctx, `
|
||||||
|
DROP TABLE IF EXISTS drop_cols;
|
||||||
|
CREATE TABLE drop_cols (
|
||||||
|
id SERIAL PRIMARY KEY NOT NULL,
|
||||||
|
f1 int NOT NULL,
|
||||||
|
f2 int NOT NULL
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
getSQL := "SELECT * FROM drop_cols WHERE id = $1"
|
||||||
|
|
||||||
|
// This query will populate the statement cache. We don't care about the result.
|
||||||
|
rows, err := conn.Query(ctx, getSQL, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
|
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||||
|
_, err = conn.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// We must get an error the first time we try to re-execute a bad statement.
|
||||||
|
// It is up to the application to determine if it wants to try again. We punt to
|
||||||
|
// the application because there is no clear recovery path in the case of failed transactions
|
||||||
|
// or batch operations and because automatic retry is tricky and we don't want to get
|
||||||
|
// it wrong at such an importaint layer of the stack.
|
||||||
|
batch := &pgx.Batch{}
|
||||||
|
batch.Queue(getSQL, 1)
|
||||||
|
br := conn.SendBatch(ctx, batch)
|
||||||
|
rows, err = br.Query()
|
||||||
|
require.Error(t, err)
|
||||||
|
rows.Next()
|
||||||
|
nextErr := rows.Err()
|
||||||
|
rows.Close()
|
||||||
|
err = br.Close()
|
||||||
|
require.Error(t, err)
|
||||||
|
for _, err := range []error{nextErr, rows.Err()} {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||||
|
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// On retry, the statement should have been flushed from the cache.
|
||||||
|
batch = &pgx.Batch{}
|
||||||
|
batch.Queue(getSQL, 1)
|
||||||
|
br = conn.SendBatch(ctx, batch)
|
||||||
|
rows, err = br.Query()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Next()
|
||||||
|
err = rows.Err()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
err = br.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStmtCacheInvalidationTxWithBatch(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
defer closeConn(t, conn)
|
||||||
|
|
||||||
|
if conn.PgConn().ParameterStatus("crdb_version") != "" {
|
||||||
|
t.Skip("Server has non-standard prepare in errored transaction behavior (https://github.com/cockroachdb/cockroach/issues/84140)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// create a table and fill it with some data
|
||||||
|
_, err := conn.Exec(ctx, `
|
||||||
|
DROP TABLE IF EXISTS drop_cols;
|
||||||
|
CREATE TABLE drop_cols (
|
||||||
|
id SERIAL PRIMARY KEY NOT NULL,
|
||||||
|
f1 int NOT NULL,
|
||||||
|
f2 int NOT NULL
|
||||||
|
);
|
||||||
|
`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = conn.Exec(ctx, "INSERT INTO drop_cols (f1, f2) VALUES (1, 2)")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tx, err := conn.Begin(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
getSQL := "SELECT * FROM drop_cols WHERE id = $1"
|
||||||
|
|
||||||
|
// This query will populate the statement cache. We don't care about the result.
|
||||||
|
rows, err := tx.Query(ctx, getSQL, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Close()
|
||||||
|
require.NoError(t, rows.Err())
|
||||||
|
|
||||||
|
// Now, change the schema of the table out from under the statement, making it invalid.
|
||||||
|
_, err = tx.Exec(ctx, "ALTER TABLE drop_cols DROP COLUMN f1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// We must get an error the first time we try to re-execute a bad statement.
|
||||||
|
// It is up to the application to determine if it wants to try again. We punt to
|
||||||
|
// the application because there is no clear recovery path in the case of failed transactions
|
||||||
|
// or batch operations and because automatic retry is tricky and we don't want to get
|
||||||
|
// it wrong at such an importaint layer of the stack.
|
||||||
|
batch := &pgx.Batch{}
|
||||||
|
batch.Queue(getSQL, 1)
|
||||||
|
br := tx.SendBatch(ctx, batch)
|
||||||
|
rows, err = br.Query()
|
||||||
|
require.Error(t, err)
|
||||||
|
rows.Next()
|
||||||
|
nextErr := rows.Err()
|
||||||
|
rows.Close()
|
||||||
|
err = br.Close()
|
||||||
|
require.Error(t, err)
|
||||||
|
for _, err := range []error{nextErr, rows.Err()} {
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(`expected "cached plan must not change result type": no error`)
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "cached plan must not change result type") {
|
||||||
|
t.Fatalf(`expected "cached plan must not change result type", got: "%s"`, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
batch = &pgx.Batch{}
|
||||||
|
batch.Queue(getSQL, 1)
|
||||||
|
br = tx.SendBatch(ctx, batch)
|
||||||
|
rows, err = br.Query()
|
||||||
|
require.Error(t, err)
|
||||||
|
rows.Close()
|
||||||
|
err = rows.Err()
|
||||||
|
// Retries within the same transaction are errors (really anything except a rollback
|
||||||
|
// will be an error in this transaction).
|
||||||
|
require.Error(t, err)
|
||||||
|
rows.Close()
|
||||||
|
err = br.Close()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
err = tx.Rollback(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// once we've rolled back, retries will work
|
||||||
|
batch = &pgx.Batch{}
|
||||||
|
batch.Queue(getSQL, 1)
|
||||||
|
br = conn.SendBatch(ctx, batch)
|
||||||
|
rows, err = br.Query()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Next()
|
||||||
|
err = rows.Err()
|
||||||
|
require.NoError(t, err)
|
||||||
|
rows.Close()
|
||||||
|
err = br.Close()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ensureConnValid(t, conn)
|
||||||
|
}
|
||||||
|
|
||||||
func TestInsertDurationInterval(t *testing.T) {
|
func TestInsertDurationInterval(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -76,7 +76,6 @@ func TestConnCopyWithAllQueryExecModes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
|
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
|
||||||
|
|
||||||
for _, mode := range pgxtest.KnownOIDQueryExecModes {
|
for _, mode := range pgxtest.KnownOIDQueryExecModes {
|
||||||
t.Run(mode.String(), func(t *testing.T) {
|
t.Run(mode.String(), func(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
@ -31,7 +31,6 @@ func (c *LRUCache) Get(key string) *pgconn.StatementDescription {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
|
// Put stores sd in the cache. Put panics if sd.SQL is "". Put does nothing if sd.SQL already exists in the cache or
|
||||||
|
@ -72,5 +72,4 @@ func testPgbouncer(t *testing.T, config *pgx.ConnConfig, workers, iterations int
|
|||||||
for i := 0; i < workers; i++ {
|
for i := 0; i < workers; i++ {
|
||||||
<-doneChan
|
<-doneChan
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -263,7 +263,7 @@ func computeClientProof(saltedPassword, authMessage []byte) []byte {
|
|||||||
return buf
|
return buf
|
||||||
}
|
}
|
||||||
|
|
||||||
func computeServerSignature(saltedPassword []byte, authMessage []byte) []byte {
|
func computeServerSignature(saltedPassword, authMessage []byte) []byte {
|
||||||
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
serverKey := computeHMAC(saltedPassword, []byte("Server Key"))
|
||||||
serverSignature := computeHMAC(serverKey, authMessage)
|
serverSignature := computeHMAC(serverKey, authMessage)
|
||||||
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
buf := make([]byte, base64.StdEncoding.EncodedLen(len(serverSignature)))
|
||||||
|
@ -78,7 +78,6 @@ func BenchmarkExec(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = rr.Close()
|
_, err = rr.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -127,7 +126,6 @@ func BenchmarkExecPossibleToCancel(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = rr.Close()
|
_, err = rr.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -184,7 +182,6 @@ func BenchmarkExecPrepared(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = rr.Close()
|
_, err = rr.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -227,7 +224,6 @@ func BenchmarkExecPreparedPossibleToCancel(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = rr.Close()
|
_, err = rr.Close()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -23,9 +23,11 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgproto3"
|
"github.com/jackc/pgx/v5/pgproto3"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
type (
|
||||||
type ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
AfterConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
type GetSSLPasswordFunc func(ctx context.Context) string
|
ValidateConnectFunc func(ctx context.Context, pgconn *PgConn) error
|
||||||
|
GetSSLPasswordFunc func(ctx context.Context) string
|
||||||
|
)
|
||||||
|
|
||||||
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
|
// Config is the settings used to establish a connection to a PostgreSQL server. It must be created by [ParseConfig]. A
|
||||||
// manually initialized Config will cause ConnectConfig to panic.
|
// manually initialized Config will cause ConnectConfig to panic.
|
||||||
@ -179,7 +181,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
//
|
//
|
||||||
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
// ParseConfig supports specifying multiple hosts in similar manner to libpq. Host and port may include comma separated
|
||||||
// values that will be tried in order. This can be used as part of a high availability system. See
|
// values that will be tried in order. This can be used as part of a high availability system. See
|
||||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||||
//
|
//
|
||||||
// # Example URL
|
// # Example URL
|
||||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||||
@ -206,9 +208,9 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
// PGTARGETSESSIONATTRS
|
// PGTARGETSESSIONATTRS
|
||||||
// PGTZ
|
// PGTZ
|
||||||
//
|
//
|
||||||
// See http://www.postgresql.org/docs/11/static/libpq-envars.html for details on the meaning of environment variables.
|
// See http://www.postgresql.org/docs/current/static/libpq-envars.html for details on the meaning of environment variables.
|
||||||
//
|
//
|
||||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
// See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-PARAMKEYWORDS for parameter key word names. They are
|
||||||
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
// usually but not always the environment variable name downcased and without the "PG" prefix.
|
||||||
//
|
//
|
||||||
// Important Security Notes:
|
// Important Security Notes:
|
||||||
@ -216,7 +218,7 @@ func NetworkAddress(host string, port uint16) (network, address string) {
|
|||||||
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
// ParseConfig tries to match libpq behavior with regard to PGSSLMODE. This includes defaulting to "prefer" behavior if
|
||||||
// not set.
|
// not set.
|
||||||
//
|
//
|
||||||
// See http://www.postgresql.org/docs/11/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
// See http://www.postgresql.org/docs/current/static/libpq-ssl.html#LIBPQ-SSL-PROTECTION for details on what level of
|
||||||
// security each sslmode provides.
|
// security each sslmode provides.
|
||||||
//
|
//
|
||||||
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
// The sslmode "prefer" (the default), sslmode "allow", and multiple hosts are implemented via the Fallbacks field of
|
||||||
@ -713,7 +715,7 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
|||||||
// According to PostgreSQL documentation, if a root CA file exists,
|
// According to PostgreSQL documentation, if a root CA file exists,
|
||||||
// the behavior of sslmode=require should be the same as that of verify-ca
|
// the behavior of sslmode=require should be the same as that of verify-ca
|
||||||
//
|
//
|
||||||
// See https://www.postgresql.org/docs/12/libpq-ssl.html
|
// See https://www.postgresql.org/docs/current/libpq-ssl.html
|
||||||
if sslrootcert != "" {
|
if sslrootcert != "" {
|
||||||
goto nextCase
|
goto nextCase
|
||||||
}
|
}
|
||||||
@ -784,8 +786,8 @@ func configTLS(settings map[string]string, thisHost string, parseConfigOptions P
|
|||||||
if sslpassword != "" {
|
if sslpassword != "" {
|
||||||
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
decryptedKey, decryptedError = x509.DecryptPEMBlock(block, []byte(sslpassword))
|
||||||
}
|
}
|
||||||
//if sslpassword not provided or has decryption error when use it
|
// if sslpassword not provided or has decryption error when use it
|
||||||
//try to find sslpassword with callback function
|
// try to find sslpassword with callback function
|
||||||
if sslpassword == "" || decryptedError != nil {
|
if sslpassword == "" || decryptedError != nil {
|
||||||
if parseConfigOptions.GetSSLPassword != nil {
|
if parseConfigOptions.GetSSLPassword != nil {
|
||||||
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
sslpassword = parseConfigOptions.GetSSLPassword(context.Background())
|
||||||
|
@ -133,7 +133,6 @@ func TestParseConfig(t *testing.T) {
|
|||||||
name: "sslmode prefer",
|
name: "sslmode prefer",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=prefer",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
|
|
||||||
User: "jack",
|
User: "jack",
|
||||||
Password: "secret",
|
Password: "secret",
|
||||||
Host: "localhost",
|
Host: "localhost",
|
||||||
@ -567,7 +566,8 @@ func TestParseConfig(t *testing.T) {
|
|||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
ServerName: "bar",
|
ServerName: "bar",
|
||||||
}},
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Host: "bar",
|
Host: "bar",
|
||||||
Port: defaultPort,
|
Port: defaultPort,
|
||||||
@ -579,7 +579,8 @@ func TestParseConfig(t *testing.T) {
|
|||||||
TLSConfig: &tls.Config{
|
TLSConfig: &tls.Config{
|
||||||
InsecureSkipVerify: true,
|
InsecureSkipVerify: true,
|
||||||
ServerName: "baz",
|
ServerName: "baz",
|
||||||
}},
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Host: "baz",
|
Host: "baz",
|
||||||
Port: defaultPort,
|
Port: defaultPort,
|
||||||
@ -1023,7 +1024,7 @@ func TestParseConfigReadsPgPassfile(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
tfName := filepath.Join(t.TempDir(), "config")
|
tfName := filepath.Join(t.TempDir(), "config")
|
||||||
err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0600)
|
err := os.WriteFile(tfName, []byte("test1:5432:curlydb:curly:nyuknyuknyuk"), 0o600)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName)
|
connString := fmt.Sprintf("postgres://curly@test1:5432/curlydb?sslmode=disable&passfile=%s", tfName)
|
||||||
@ -1061,7 +1062,7 @@ host = def.example.com
|
|||||||
dbname = defdb
|
dbname = defdb
|
||||||
user = defuser
|
user = defuser
|
||||||
application_name = spaced string
|
application_name = spaced string
|
||||||
`), 0600)
|
`), 0o600)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
defaultPort := getDefaultPort(t)
|
defaultPort := getDefaultPort(t)
|
||||||
|
@ -27,7 +27,7 @@ func Timeout(err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// PgError represents an error reported by the PostgreSQL server. See
|
// PgError represents an error reported by the PostgreSQL server. See
|
||||||
// http://www.postgresql.org/docs/11/static/protocol-error-fields.html for
|
// http://www.postgresql.org/docs/current/static/protocol-error-fields.html for
|
||||||
// detailed field description.
|
// detailed field description.
|
||||||
type PgError struct {
|
type PgError struct {
|
||||||
Severity string
|
Severity string
|
||||||
@ -112,6 +112,14 @@ type ParseConfigError struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewParseConfigError(conn, msg string, err error) error {
|
||||||
|
return &ParseConfigError{
|
||||||
|
ConnString: conn,
|
||||||
|
msg: msg,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *ParseConfigError) Error() string {
|
func (e *ParseConfigError) Error() string {
|
||||||
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
|
// Now that ParseConfigError is public and ConnString is available to the developer, perhaps it would be better only
|
||||||
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
|
// return a static string. That would ensure that the error message cannot leak a password. The ConnString field would
|
||||||
|
@ -1,11 +1,3 @@
|
|||||||
// File export_test exports some methods for better testing.
|
// File export_test exports some methods for better testing.
|
||||||
|
|
||||||
package pgconn
|
package pgconn
|
||||||
|
|
||||||
func NewParseConfigError(conn, msg string, err error) error {
|
|
||||||
return &ParseConfigError{
|
|
||||||
ConnString: conn,
|
|
||||||
msg: msg,
|
|
||||||
err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -28,7 +28,7 @@ func RegisterGSSProvider(newGSSArg NewGSSFunc) {
|
|||||||
|
|
||||||
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
// GSS provides GSSAPI authentication (e.g., Kerberos).
|
||||||
type GSS interface {
|
type GSS interface {
|
||||||
GetInitToken(host string, service string) ([]byte, error)
|
GetInitToken(host, service string) ([]byte, error)
|
||||||
GetInitTokenFromSPN(spn string) ([]byte, error)
|
GetInitTokenFromSPN(spn string) ([]byte, error)
|
||||||
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
Continue(inToken []byte) (done bool, outToken []byte, err error)
|
||||||
}
|
}
|
||||||
|
@ -135,7 +135,7 @@ func ConnectWithOptions(ctx context.Context, connString string, parseConfigOptio
|
|||||||
//
|
//
|
||||||
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
|
// If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An
|
||||||
// authentication error will terminate the chain of attempts (like libpq:
|
// authentication error will terminate the chain of attempts (like libpq:
|
||||||
// https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
|
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error.
|
||||||
func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
|
func ConnectConfig(ctx context.Context, config *Config) (*PgConn, error) {
|
||||||
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
|
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
|
||||||
// zero values.
|
// zero values.
|
||||||
@ -991,7 +991,8 @@ func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice {
|
|||||||
|
|
||||||
// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
|
// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
|
||||||
// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
|
// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
|
||||||
// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9
|
// is no way to be sure a query was canceled.
|
||||||
|
// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-CANCELING-REQUESTS
|
||||||
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
|
||||||
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
|
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
|
||||||
// the connection config. This is important in high availability configurations where fallback connections may be
|
// the connection config. This is important in high availability configurations where fallback connections may be
|
||||||
@ -1140,7 +1141,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
|
|||||||
// binary format. If resultFormats is nil all results will be in text format.
|
// binary format. If resultFormats is nil all results will be in text format.
|
||||||
//
|
//
|
||||||
// ResultReader must be closed before PgConn can be used again.
|
// ResultReader must be closed before PgConn can be used again.
|
||||||
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader {
|
func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) *ResultReader {
|
||||||
result := pgConn.execExtendedPrefix(ctx, paramValues)
|
result := pgConn.execExtendedPrefix(ctx, paramValues)
|
||||||
if result.closed {
|
if result.closed {
|
||||||
return result
|
return result
|
||||||
@ -1166,7 +1167,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
|
|||||||
// binary format. If resultFormats is nil all results will be in text format.
|
// binary format. If resultFormats is nil all results will be in text format.
|
||||||
//
|
//
|
||||||
// ResultReader must be closed before PgConn can be used again.
|
// ResultReader must be closed before PgConn can be used again.
|
||||||
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader {
|
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) *ResultReader {
|
||||||
result := pgConn.execExtendedPrefix(ctx, paramValues)
|
result := pgConn.execExtendedPrefix(ctx, paramValues)
|
||||||
if result.closed {
|
if result.closed {
|
||||||
return result
|
return result
|
||||||
@ -1373,7 +1374,14 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
|
|||||||
close(pgConn.cleanupDone)
|
close(pgConn.cleanupDone)
|
||||||
return CommandTag{}, normalizeTimeoutError(ctx, err)
|
return CommandTag{}, normalizeTimeoutError(ctx, err)
|
||||||
}
|
}
|
||||||
msg, _ := pgConn.receiveMessage()
|
// peekMessage never returns err in the bufferingReceive mode - it only forwards the bufferingReceive variables.
|
||||||
|
// Therefore, the only case for receiveMessage to return err is during handling of the ErrorResponse message type
|
||||||
|
// and using pgOnError handler to determine the connection is no longer valid (and thus closing the conn).
|
||||||
|
msg, serverError := pgConn.receiveMessage()
|
||||||
|
if serverError != nil {
|
||||||
|
close(abortCopyChan)
|
||||||
|
return CommandTag{}, serverError
|
||||||
|
}
|
||||||
|
|
||||||
switch msg := msg.(type) {
|
switch msg := msg.(type) {
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
@ -1712,7 +1720,7 @@ type Batch struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
|
// ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions.
|
||||||
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) {
|
||||||
if batch.err != nil {
|
if batch.err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -1725,7 +1733,7 @@ func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
|
// ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions.
|
||||||
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) {
|
||||||
if batch.err != nil {
|
if batch.err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -2201,7 +2209,7 @@ func (p *Pipeline) SendDeallocate(name string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
// SendQueryParams is the pipeline version of *PgConn.QueryParams.
|
||||||
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) {
|
func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats, resultFormats []int16) {
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -2214,7 +2222,7 @@ func (p *Pipeline) SendQueryParams(sql string, paramValues [][]byte, paramOIDs [
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
// SendQueryPrepared is the pipeline version of *PgConn.QueryPrepared.
|
||||||
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) {
|
func (p *Pipeline) SendQueryPrepared(stmtName string, paramValues [][]byte, paramFormats, resultFormats []int16) {
|
||||||
if p.closed {
|
if p.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ import (
|
|||||||
func TestCommandTag(t *testing.T) {
|
func TestCommandTag(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
var tests = []struct {
|
tests := []struct {
|
||||||
commandTag CommandTag
|
commandTag CommandTag
|
||||||
rowsAffected int64
|
rowsAffected int64
|
||||||
isInsert bool
|
isInsert bool
|
||||||
|
@ -2130,6 +2130,63 @@ func TestConnCopyFromPrecanceled(t *testing.T) {
|
|||||||
ensureConnValid(t, pgConn)
|
ensureConnValid(t, pgConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/jackc/pgx/issues/2364
|
||||||
|
func TestConnCopyFromConnectionTerminated(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
||||||
|
t.Skip("Server does not support pg_terminate_backend")
|
||||||
|
}
|
||||||
|
|
||||||
|
closerConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
time.AfterFunc(500*time.Millisecond, func() {
|
||||||
|
// defer inside of AfterFunc instead of outer test function because outer function can finish while Read is still in
|
||||||
|
// progress which could cause closerConn to be closed too soon.
|
||||||
|
defer closeConn(t, closerConn)
|
||||||
|
err := closerConn.ExecParams(ctx, "select pg_terminate_backend($1)", [][]byte{[]byte(fmt.Sprintf("%d", pgConn.PID()))}, nil, nil, nil).Read().Err
|
||||||
|
require.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
||||||
|
a int4,
|
||||||
|
b varchar
|
||||||
|
)`).ReadAll()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
r, w := io.Pipe()
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 5_000; i++ {
|
||||||
|
a := strconv.Itoa(i)
|
||||||
|
b := "foo " + a + " bar"
|
||||||
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
time.Sleep(time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
||||||
|
ct, err := pgConn.CopyFrom(ctx, r, copySql)
|
||||||
|
assert.Equal(t, int64(0), ct.RowsAffected())
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
assert.True(t, pgConn.IsClosed())
|
||||||
|
select {
|
||||||
|
case <-pgConn.CleanupDone():
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Connection cleanup exceeded maximum time")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnCopyFromGzipReader(t *testing.T) {
|
func TestConnCopyFromGzipReader(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -9,8 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
// AuthenticationCleartextPassword is a message sent from the backend indicating that a clear-text password is required.
|
||||||
type AuthenticationCleartextPassword struct {
|
type AuthenticationCleartextPassword struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
func (*AuthenticationCleartextPassword) Backend() {}
|
func (*AuthenticationCleartextPassword) Backend() {}
|
||||||
|
@ -9,8 +9,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
// AuthenticationOk is a message sent from the backend indicating that authentication was successful.
|
||||||
type AuthenticationOk struct {
|
type AuthenticationOk struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
func (*AuthenticationOk) Backend() {}
|
func (*AuthenticationOk) Backend() {}
|
||||||
|
@ -4,8 +4,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CopyDone struct {
|
type CopyDone struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// Backend identifies this message as sendable by the PostgreSQL backend.
|
// Backend identifies this message as sendable by the PostgreSQL backend.
|
||||||
func (*CopyDone) Backend() {}
|
func (*CopyDone) Backend() {}
|
||||||
|
@ -10,8 +10,7 @@ import (
|
|||||||
|
|
||||||
const gssEncReqNumber = 80877104
|
const gssEncReqNumber = 80877104
|
||||||
|
|
||||||
type GSSEncRequest struct {
|
type GSSEncRequest struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
func (*GSSEncRequest) Frontend() {}
|
func (*GSSEncRequest) Frontend() {}
|
||||||
|
@ -332,7 +332,7 @@ func TestJSONUnmarshalRowDescription(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestJSONUnmarshalBind(t *testing.T) {
|
func TestJSONUnmarshalBind(t *testing.T) {
|
||||||
var testCases = []struct {
|
testCases := []struct {
|
||||||
desc string
|
desc string
|
||||||
data []byte
|
data []byte
|
||||||
}{
|
}{
|
||||||
@ -348,7 +348,7 @@ func TestJSONUnmarshalBind(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
var want = Bind{
|
want := Bind{
|
||||||
PreparedStatement: "lrupsc_1_0",
|
PreparedStatement: "lrupsc_1_0",
|
||||||
ParameterFormatCodes: []int16{0},
|
ParameterFormatCodes: []int16{0},
|
||||||
Parameters: [][]byte{[]byte("ABC-123")},
|
Parameters: [][]byte{[]byte("ABC-123")},
|
||||||
|
@ -56,7 +56,6 @@ func (*RowDescription) Backend() {}
|
|||||||
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
// Decode decodes src into dst. src must contain the complete message with the exception of the initial 1 byte message
|
||||||
// type identifier and 4 byte message length.
|
// type identifier and 4 byte message length.
|
||||||
func (dst *RowDescription) Decode(src []byte) error {
|
func (dst *RowDescription) Decode(src []byte) error {
|
||||||
|
|
||||||
if len(src) < 2 {
|
if len(src) < 2 {
|
||||||
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
return &invalidMessageFormatErr{messageType: "RowDescription"}
|
||||||
}
|
}
|
||||||
|
@ -10,8 +10,7 @@ import (
|
|||||||
|
|
||||||
const sslRequestNumber = 80877103
|
const sslRequestNumber = 80877103
|
||||||
|
|
||||||
type SSLRequest struct {
|
type SSLRequest struct{}
|
||||||
}
|
|
||||||
|
|
||||||
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
// Frontend identifies this message as sendable by a PostgreSQL frontend.
|
||||||
func (*SSLRequest) Frontend() {}
|
func (*SSLRequest) Frontend() {}
|
||||||
|
@ -374,8 +374,8 @@ func quoteArrayElementIfNeeded(src string) string {
|
|||||||
return src
|
return src
|
||||||
}
|
}
|
||||||
|
|
||||||
// Array represents a PostgreSQL array for T. It implements the ArrayGetter and ArraySetter interfaces. It preserves
|
// Array represents a PostgreSQL array for T. It implements the [ArrayGetter] and [ArraySetter] interfaces. It preserves
|
||||||
// PostgreSQL dimensions and custom lower bounds. Use FlatArray if these are not needed.
|
// PostgreSQL dimensions and custom lower bounds. Use [FlatArray] if these are not needed.
|
||||||
type Array[T any] struct {
|
type Array[T any] struct {
|
||||||
Elements []T
|
Elements []T
|
||||||
Dims []ArrayDimension
|
Dims []ArrayDimension
|
||||||
@ -419,8 +419,8 @@ func (a Array[T]) ScanIndexType() any {
|
|||||||
return new(T)
|
return new(T)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlatArray implements the ArrayGetter and ArraySetter interfaces for any slice of T. It ignores PostgreSQL dimensions
|
// FlatArray implements the [ArrayGetter] and [ArraySetter] interfaces for any slice of T. It ignores PostgreSQL dimensions
|
||||||
// and custom lower bounds. Use Array to preserve these.
|
// and custom lower bounds. Use [Array] to preserve these.
|
||||||
type FlatArray[T any] []T
|
type FlatArray[T any] []T
|
||||||
|
|
||||||
func (a FlatArray[T]) Dimensions() []ArrayDimension {
|
func (a FlatArray[T]) Dimensions() []ArrayDimension {
|
||||||
|
@ -256,7 +256,6 @@ func TestArrayCodecScanMultipleDimensions(t *testing.T) {
|
|||||||
skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)")
|
skipCockroachDB(t, "Server does not support nested arrays (https://github.com/cockroachdb/cockroach/issues/36815)")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`)
|
rows, err := conn.Query(ctx, `select '{{1,2,3,4}, {5,6,7,8}, {9,10,11,12}}'::int4[]`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
@ -23,16 +23,18 @@ type Bits struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanBits implements the [BitsScanner] interface.
|
||||||
func (b *Bits) ScanBits(v Bits) error {
|
func (b *Bits) ScanBits(v Bits) error {
|
||||||
*b = v
|
*b = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BitsValue implements the [BitsValuer] interface.
|
||||||
func (b Bits) BitsValue() (Bits, error) {
|
func (b Bits) BitsValue() (Bits, error) {
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Bits) Scan(src any) error {
|
func (dst *Bits) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Bits{}
|
*dst = Bits{}
|
||||||
@ -47,7 +49,7 @@ func (dst *Bits) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Bits) Value() (driver.Value, error) {
|
func (src Bits) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -127,7 +129,6 @@ func (encodePlanBitsCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (BitsCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -22,16 +22,18 @@ type Bool struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanBool implements the [BoolScanner] interface.
|
||||||
func (b *Bool) ScanBool(v Bool) error {
|
func (b *Bool) ScanBool(v Bool) error {
|
||||||
*b = v
|
*b = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BoolValue implements the [BoolValuer] interface.
|
||||||
func (b Bool) BoolValue() (Bool, error) {
|
func (b Bool) BoolValue() (Bool, error) {
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Bool) Scan(src any) error {
|
func (dst *Bool) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Bool{}
|
*dst = Bool{}
|
||||||
@ -61,7 +63,7 @@ func (dst *Bool) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Bool) Value() (driver.Value, error) {
|
func (src Bool) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -70,6 +72,7 @@ func (src Bool) Value() (driver.Value, error) {
|
|||||||
return src.Bool, nil
|
return src.Bool, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Bool) MarshalJSON() ([]byte, error) {
|
func (src Bool) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -82,6 +85,7 @@ func (src Bool) MarshalJSON() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Bool) UnmarshalJSON(b []byte) error {
|
func (dst *Bool) UnmarshalJSON(b []byte) error {
|
||||||
var v *bool
|
var v *bool
|
||||||
err := json.Unmarshal(b, &v)
|
err := json.Unmarshal(b, &v)
|
||||||
@ -200,7 +204,6 @@ func (encodePlanBoolCodecTextBool) Encode(value any, buf []byte) (newBuf []byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (BoolCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
@ -328,7 +331,7 @@ func (scanPlanTextAnyToBoolScanner) Scan(src []byte, dst any) error {
|
|||||||
return s.ScanBool(Bool{Bool: v, Valid: true})
|
return s.ScanBool(Bool{Bool: v, Valid: true})
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://www.postgresql.org/docs/11/datatype-boolean.html
|
// https://www.postgresql.org/docs/current/datatype-boolean.html
|
||||||
func planTextToBool(src []byte) (bool, error) {
|
func planTextToBool(src []byte) (bool, error) {
|
||||||
s := string(bytes.ToLower(bytes.TrimSpace(src)))
|
s := string(bytes.ToLower(bytes.TrimSpace(src)))
|
||||||
|
|
||||||
|
@ -24,16 +24,18 @@ type Box struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanBox implements the [BoxScanner] interface.
|
||||||
func (b *Box) ScanBox(v Box) error {
|
func (b *Box) ScanBox(v Box) error {
|
||||||
*b = v
|
*b = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BoxValue implements the [BoxValuer] interface.
|
||||||
func (b Box) BoxValue() (Box, error) {
|
func (b Box) BoxValue() (Box, error) {
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Box) Scan(src any) error {
|
func (dst *Box) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Box{}
|
*dst = Box{}
|
||||||
@ -48,7 +50,7 @@ func (dst *Box) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Box) Value() (driver.Value, error) {
|
func (src Box) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -127,7 +129,6 @@ func (encodePlanBoxCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (BoxCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -527,6 +527,7 @@ func (w *netIPNetWrapper) ScanNetipPrefix(v netip.Prefix) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) {
|
func (w netIPNetWrapper) NetipPrefixValue() (netip.Prefix, error) {
|
||||||
ip, ok := netip.AddrFromSlice(w.IP)
|
ip, ok := netip.AddrFromSlice(w.IP)
|
||||||
if !ok {
|
if !ok {
|
||||||
@ -881,7 +882,6 @@ func (a *anyMultiDimSliceArray) SetDimensions(dimensions []ArrayDimension) error
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value {
|
func (a *anyMultiDimSliceArray) makeMultidimensionalSlice(sliceType reflect.Type, dimensions []ArrayDimension, flatSlice reflect.Value, flatSliceIdx int) reflect.Value {
|
||||||
|
@ -148,7 +148,6 @@ func (encodePlanBytesCodecTextBytesValuer) Encode(value any, buf []byte) (newBuf
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (ByteaCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -25,16 +25,18 @@ type Circle struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanCircle implements the [CircleScanner] interface.
|
||||||
func (c *Circle) ScanCircle(v Circle) error {
|
func (c *Circle) ScanCircle(v Circle) error {
|
||||||
*c = v
|
*c = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CircleValue implements the [CircleValuer] interface.
|
||||||
func (c Circle) CircleValue() (Circle, error) {
|
func (c Circle) CircleValue() (Circle, error) {
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Circle) Scan(src any) error {
|
func (dst *Circle) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Circle{}
|
*dst = Circle{}
|
||||||
@ -49,7 +51,7 @@ func (dst *Circle) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Circle) Value() (driver.Value, error) {
|
func (src Circle) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -276,7 +276,6 @@ func (c *CompositeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byt
|
|||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown format code %d", format)
|
return nil, fmt.Errorf("unknown format code %d", format)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CompositeBinaryScanner struct {
|
type CompositeBinaryScanner struct {
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
|
|
||||||
func TestCompositeCodecTranscode(t *testing.T) {
|
func TestCompositeCodecTranscode(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists ct_test;
|
_, err := conn.Exec(ctx, `drop type if exists ct_test;
|
||||||
|
|
||||||
create type ct_test as (
|
create type ct_test as (
|
||||||
@ -90,7 +89,6 @@ func (p *point3d) ScanIndex(i int) any {
|
|||||||
|
|
||||||
func TestCompositeCodecTranscodeStruct(t *testing.T) {
|
func TestCompositeCodecTranscodeStruct(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||||
|
|
||||||
create type point3d as (
|
create type point3d as (
|
||||||
@ -125,7 +123,6 @@ create type point3d as (
|
|||||||
|
|
||||||
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||||
|
|
||||||
create type point3d as (
|
create type point3d as (
|
||||||
@ -164,7 +161,6 @@ create type point3d as (
|
|||||||
|
|
||||||
func TestCompositeCodecDecodeValue(t *testing.T) {
|
func TestCompositeCodecDecodeValue(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
_, err := conn.Exec(ctx, `drop type if exists point3d;
|
||||||
|
|
||||||
create type point3d as (
|
create type point3d as (
|
||||||
@ -209,7 +205,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
|
|||||||
skipCockroachDB(t, "Server does not support composite types from table definitions")
|
skipCockroachDB(t, "Server does not support composite types from table definitions")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop table if exists point3d;
|
_, err := conn.Exec(ctx, `drop table if exists point3d;
|
||||||
|
|
||||||
create table point3d (
|
create table point3d (
|
||||||
|
@ -26,11 +26,13 @@ type Date struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanDate implements the [DateScanner] interface.
|
||||||
func (d *Date) ScanDate(v Date) error {
|
func (d *Date) ScanDate(v Date) error {
|
||||||
*d = v
|
*d = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DateValue implements the [DateValuer] interface.
|
||||||
func (d Date) DateValue() (Date, error) {
|
func (d Date) DateValue() (Date, error) {
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
@ -40,7 +42,7 @@ const (
|
|||||||
infinityDayOffset = 2147483647
|
infinityDayOffset = 2147483647
|
||||||
)
|
)
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Date) Scan(src any) error {
|
func (dst *Date) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Date{}
|
*dst = Date{}
|
||||||
@ -58,7 +60,7 @@ func (dst *Date) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Date) Value() (driver.Value, error) {
|
func (src Date) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -70,6 +72,7 @@ func (src Date) Value() (driver.Value, error) {
|
|||||||
return src.Time, nil
|
return src.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Date) MarshalJSON() ([]byte, error) {
|
func (src Date) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -89,6 +92,7 @@ func (src Date) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(s)
|
return json.Marshal(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Date) UnmarshalJSON(b []byte) error {
|
func (dst *Date) UnmarshalJSON(b []byte) error {
|
||||||
var s *string
|
var s *string
|
||||||
err := json.Unmarshal(b, &s)
|
err := json.Unmarshal(b, &s)
|
||||||
@ -223,7 +227,6 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (DateCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
func TestEnumCodec(t *testing.T) {
|
func TestEnumCodec(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists enum_test;
|
_, err := conn.Exec(ctx, `drop type if exists enum_test;
|
||||||
|
|
||||||
create type enum_test as enum ('foo', 'bar', 'baz');`)
|
create type enum_test as enum ('foo', 'bar', 'baz');`)
|
||||||
@ -47,7 +46,6 @@ create type enum_test as enum ('foo', 'bar', 'baz');`)
|
|||||||
|
|
||||||
func TestEnumCodecValues(t *testing.T) {
|
func TestEnumCodecValues(t *testing.T) {
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
_, err := conn.Exec(ctx, `drop type if exists enum_test;
|
_, err := conn.Exec(ctx, `drop type if exists enum_test;
|
||||||
|
|
||||||
create type enum_test as enum ('foo', 'bar', 'baz');`)
|
create type enum_test as enum ('foo', 'bar', 'baz');`)
|
||||||
|
@ -16,26 +16,29 @@ type Float4 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanFloat64 implements the Float64Scanner interface.
|
// ScanFloat64 implements the [Float64Scanner] interface.
|
||||||
func (f *Float4) ScanFloat64(n Float8) error {
|
func (f *Float4) ScanFloat64(n Float8) error {
|
||||||
*f = Float4{Float32: float32(n.Float64), Valid: n.Valid}
|
*f = Float4{Float32: float32(n.Float64), Valid: n.Valid}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Float64Value implements the [Float64Valuer] interface.
|
||||||
func (f Float4) Float64Value() (Float8, error) {
|
func (f Float4) Float64Value() (Float8, error) {
|
||||||
return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil
|
return Float8{Float64: float64(f.Float32), Valid: f.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (f *Float4) ScanInt64(n Int8) error {
|
func (f *Float4) ScanInt64(n Int8) error {
|
||||||
*f = Float4{Float32: float32(n.Int64), Valid: n.Valid}
|
*f = Float4{Float32: float32(n.Int64), Valid: n.Valid}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (f Float4) Int64Value() (Int8, error) {
|
func (f Float4) Int64Value() (Int8, error) {
|
||||||
return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil
|
return Int8{Int64: int64(f.Float32), Valid: f.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (f *Float4) Scan(src any) error {
|
func (f *Float4) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*f = Float4{}
|
*f = Float4{}
|
||||||
@ -58,7 +61,7 @@ func (f *Float4) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (f Float4) Value() (driver.Value, error) {
|
func (f Float4) Value() (driver.Value, error) {
|
||||||
if !f.Valid {
|
if !f.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -66,6 +69,7 @@ func (f Float4) Value() (driver.Value, error) {
|
|||||||
return float64(f.Float32), nil
|
return float64(f.Float32), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (f Float4) MarshalJSON() ([]byte, error) {
|
func (f Float4) MarshalJSON() ([]byte, error) {
|
||||||
if !f.Valid {
|
if !f.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -73,6 +77,7 @@ func (f Float4) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(f.Float32)
|
return json.Marshal(f.Float32)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (f *Float4) UnmarshalJSON(b []byte) error {
|
func (f *Float4) UnmarshalJSON(b []byte) error {
|
||||||
var n *float32
|
var n *float32
|
||||||
err := json.Unmarshal(b, &n)
|
err := json.Unmarshal(b, &n)
|
||||||
@ -170,7 +175,6 @@ func (encodePlanFloat4CodecBinaryInt64Valuer) Encode(value any, buf []byte) (new
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (Float4Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -24,26 +24,29 @@ type Float8 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanFloat64 implements the Float64Scanner interface.
|
// ScanFloat64 implements the [Float64Scanner] interface.
|
||||||
func (f *Float8) ScanFloat64(n Float8) error {
|
func (f *Float8) ScanFloat64(n Float8) error {
|
||||||
*f = n
|
*f = n
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Float64Value implements the [Float64Valuer] interface.
|
||||||
func (f Float8) Float64Value() (Float8, error) {
|
func (f Float8) Float64Value() (Float8, error) {
|
||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (f *Float8) ScanInt64(n Int8) error {
|
func (f *Float8) ScanInt64(n Int8) error {
|
||||||
*f = Float8{Float64: float64(n.Int64), Valid: n.Valid}
|
*f = Float8{Float64: float64(n.Int64), Valid: n.Valid}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (f Float8) Int64Value() (Int8, error) {
|
func (f Float8) Int64Value() (Int8, error) {
|
||||||
return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil
|
return Int8{Int64: int64(f.Float64), Valid: f.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (f *Float8) Scan(src any) error {
|
func (f *Float8) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*f = Float8{}
|
*f = Float8{}
|
||||||
@ -66,7 +69,7 @@ func (f *Float8) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (f Float8) Value() (driver.Value, error) {
|
func (f Float8) Value() (driver.Value, error) {
|
||||||
if !f.Valid {
|
if !f.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -74,6 +77,7 @@ func (f Float8) Value() (driver.Value, error) {
|
|||||||
return f.Float64, nil
|
return f.Float64, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (f Float8) MarshalJSON() ([]byte, error) {
|
func (f Float8) MarshalJSON() ([]byte, error) {
|
||||||
if !f.Valid {
|
if !f.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -81,6 +85,7 @@ func (f Float8) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(f.Float64)
|
return json.Marshal(f.Float64)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (f *Float8) UnmarshalJSON(b []byte) error {
|
func (f *Float8) UnmarshalJSON(b []byte) error {
|
||||||
var n *float64
|
var n *float64
|
||||||
err := json.Unmarshal(b, &n)
|
err := json.Unmarshal(b, &n)
|
||||||
@ -208,7 +213,6 @@ func (encodePlanTextInt64Valuer) Encode(value any, buf []byte) (newBuf []byte, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (Float8Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -22,16 +22,18 @@ type HstoreValuer interface {
|
|||||||
// associated with its keys.
|
// associated with its keys.
|
||||||
type Hstore map[string]*string
|
type Hstore map[string]*string
|
||||||
|
|
||||||
|
// ScanHstore implements the [HstoreScanner] interface.
|
||||||
func (h *Hstore) ScanHstore(v Hstore) error {
|
func (h *Hstore) ScanHstore(v Hstore) error {
|
||||||
*h = v
|
*h = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// HstoreValue implements the [HstoreValuer] interface.
|
||||||
func (h Hstore) HstoreValue() (Hstore, error) {
|
func (h Hstore) HstoreValue() (Hstore, error) {
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (h *Hstore) Scan(src any) error {
|
func (h *Hstore) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*h = nil
|
*h = nil
|
||||||
@ -46,7 +48,7 @@ func (h *Hstore) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (h Hstore) Value() (driver.Value, error) {
|
func (h Hstore) Value() (driver.Value, error) {
|
||||||
if h == nil {
|
if h == nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -162,7 +164,6 @@ func (encodePlanHstoreCodecText) Encode(value any, buf []byte) (newBuf []byte, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (HstoreCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
@ -298,7 +299,7 @@ func (p *hstoreParser) consume() (b byte, end bool) {
|
|||||||
return b, false
|
return b, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func unexpectedByteErr(actualB byte, expectedB byte) error {
|
func unexpectedByteErr(actualB, expectedB byte) error {
|
||||||
return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB)
|
return fmt.Errorf("expected '%c' ('%#v'); found '%c' ('%#v')", expectedB, expectedB, actualB, actualB)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -316,7 +317,7 @@ func (p *hstoreParser) consumeExpectedByte(expectedB byte) error {
|
|||||||
|
|
||||||
// consumeExpected2 consumes two expected bytes or returns an error.
|
// consumeExpected2 consumes two expected bytes or returns an error.
|
||||||
// This was a bit faster than using a string argument (better inlining? Not sure).
|
// This was a bit faster than using a string argument (better inlining? Not sure).
|
||||||
func (p *hstoreParser) consumeExpected2(one byte, two byte) error {
|
func (p *hstoreParser) consumeExpected2(one, two byte) error {
|
||||||
if p.pos+2 > len(p.str) {
|
if p.pos+2 > len(p.str) {
|
||||||
return errors.New("unexpected end of string")
|
return errors.New("unexpected end of string")
|
||||||
}
|
}
|
||||||
|
@ -306,12 +306,13 @@ func TestRoundTrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHstoreEncode(b *testing.B) {
|
func BenchmarkHstoreEncode(b *testing.B) {
|
||||||
h := pgtype.Hstore{"a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"),
|
h := pgtype.Hstore{
|
||||||
"d": stringPtr("400"), "e": stringPtr("500")}
|
"a x": stringPtr("100"), "b": stringPtr("200"), "c": stringPtr("300"),
|
||||||
|
"d": stringPtr("400"), "e": stringPtr("500"),
|
||||||
|
}
|
||||||
|
|
||||||
serializeConfigs := []struct {
|
serializeConfigs := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -24,7 +24,7 @@ type NetipPrefixValuer interface {
|
|||||||
NetipPrefixValue() (netip.Prefix, error)
|
NetipPrefixValue() (netip.Prefix, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are netip.Prefix and netip.Addr. If
|
// InetCodec handles both inet and cidr PostgreSQL types. The preferred Go types are [netip.Prefix] and [netip.Addr]. If
|
||||||
// IsValid() is false then they are treated as SQL NULL.
|
// IsValid() is false then they are treated as SQL NULL.
|
||||||
type InetCodec struct{}
|
type InetCodec struct{}
|
||||||
|
|
||||||
@ -107,7 +107,6 @@ func (encodePlanInetCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (InetCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -26,7 +26,7 @@ type Int2 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (dst *Int2) ScanInt64(n Int8) error {
|
func (dst *Int2) ScanInt64(n Int8) error {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
*dst = Int2{}
|
*dst = Int2{}
|
||||||
@ -44,11 +44,12 @@ func (dst *Int2) ScanInt64(n Int8) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (n Int2) Int64Value() (Int8, error) {
|
func (n Int2) Int64Value() (Int8, error) {
|
||||||
return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil
|
return Int8{Int64: int64(n.Int16), Valid: n.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int2) Scan(src any) error {
|
func (dst *Int2) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Int2{}
|
*dst = Int2{}
|
||||||
@ -87,7 +88,7 @@ func (dst *Int2) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int2) Value() (driver.Value, error) {
|
func (src Int2) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -95,6 +96,7 @@ func (src Int2) Value() (driver.Value, error) {
|
|||||||
return int64(src.Int16), nil
|
return int64(src.Int16), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Int2) MarshalJSON() ([]byte, error) {
|
func (src Int2) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -102,6 +104,7 @@ func (src Int2) MarshalJSON() ([]byte, error) {
|
|||||||
return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil
|
return []byte(strconv.FormatInt(int64(src.Int16), 10)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Int2) UnmarshalJSON(b []byte) error {
|
func (dst *Int2) UnmarshalJSON(b []byte) error {
|
||||||
var n *int16
|
var n *int16
|
||||||
err := json.Unmarshal(b, &n)
|
err := json.Unmarshal(b, &n)
|
||||||
@ -586,7 +589,7 @@ type Int4 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (dst *Int4) ScanInt64(n Int8) error {
|
func (dst *Int4) ScanInt64(n Int8) error {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
*dst = Int4{}
|
*dst = Int4{}
|
||||||
@ -604,11 +607,12 @@ func (dst *Int4) ScanInt64(n Int8) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (n Int4) Int64Value() (Int8, error) {
|
func (n Int4) Int64Value() (Int8, error) {
|
||||||
return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil
|
return Int8{Int64: int64(n.Int32), Valid: n.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int4) Scan(src any) error {
|
func (dst *Int4) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Int4{}
|
*dst = Int4{}
|
||||||
@ -647,7 +651,7 @@ func (dst *Int4) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int4) Value() (driver.Value, error) {
|
func (src Int4) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -655,6 +659,7 @@ func (src Int4) Value() (driver.Value, error) {
|
|||||||
return int64(src.Int32), nil
|
return int64(src.Int32), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Int4) MarshalJSON() ([]byte, error) {
|
func (src Int4) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -662,6 +667,7 @@ func (src Int4) MarshalJSON() ([]byte, error) {
|
|||||||
return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil
|
return []byte(strconv.FormatInt(int64(src.Int32), 10)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Int4) UnmarshalJSON(b []byte) error {
|
func (dst *Int4) UnmarshalJSON(b []byte) error {
|
||||||
var n *int32
|
var n *int32
|
||||||
err := json.Unmarshal(b, &n)
|
err := json.Unmarshal(b, &n)
|
||||||
@ -1157,7 +1163,7 @@ type Int8 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (dst *Int8) ScanInt64(n Int8) error {
|
func (dst *Int8) ScanInt64(n Int8) error {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
*dst = Int8{}
|
*dst = Int8{}
|
||||||
@ -1175,11 +1181,12 @@ func (dst *Int8) ScanInt64(n Int8) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (n Int8) Int64Value() (Int8, error) {
|
func (n Int8) Int64Value() (Int8, error) {
|
||||||
return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil
|
return Int8{Int64: int64(n.Int64), Valid: n.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int8) Scan(src any) error {
|
func (dst *Int8) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Int8{}
|
*dst = Int8{}
|
||||||
@ -1218,7 +1225,7 @@ func (dst *Int8) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int8) Value() (driver.Value, error) {
|
func (src Int8) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -1226,6 +1233,7 @@ func (src Int8) Value() (driver.Value, error) {
|
|||||||
return int64(src.Int64), nil
|
return int64(src.Int64), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Int8) MarshalJSON() ([]byte, error) {
|
func (src Int8) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -1233,6 +1241,7 @@ func (src Int8) MarshalJSON() ([]byte, error) {
|
|||||||
return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil
|
return []byte(strconv.FormatInt(int64(src.Int64), 10)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Int8) UnmarshalJSON(b []byte) error {
|
func (dst *Int8) UnmarshalJSON(b []byte) error {
|
||||||
var n *int64
|
var n *int64
|
||||||
err := json.Unmarshal(b, &n)
|
err := json.Unmarshal(b, &n)
|
||||||
|
@ -27,7 +27,7 @@ type Int<%= pg_byte_size %> struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error {
|
func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
*dst = Int<%= pg_byte_size %>{}
|
*dst = Int<%= pg_byte_size %>{}
|
||||||
@ -45,11 +45,12 @@ func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) {
|
func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) {
|
||||||
return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil
|
return Int8{Int64: int64(n.Int<%= pg_bit_size %>), Valid: n.Valid}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
|
func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Int<%= pg_byte_size %>{}
|
*dst = Int<%= pg_byte_size %>{}
|
||||||
@ -88,7 +89,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
|
func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -96,6 +97,7 @@ func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
|
|||||||
return int64(src.Int<%= pg_bit_size %>), nil
|
return int64(src.Int<%= pg_bit_size %>), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) {
|
func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -103,6 +105,7 @@ func (src Int<%= pg_byte_size %>) MarshalJSON() ([]byte, error) {
|
|||||||
return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil
|
return []byte(strconv.FormatInt(int64(src.Int<%= pg_bit_size %>), 10)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error {
|
func (dst *Int<%= pg_byte_size %>) UnmarshalJSON(b []byte) error {
|
||||||
var n *int<%= pg_bit_size %>
|
var n *int<%= pg_bit_size %>
|
||||||
err := json.Unmarshal(b, &n)
|
err := json.Unmarshal(b, &n)
|
||||||
|
@ -33,16 +33,18 @@ type Interval struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanInterval implements the [IntervalScanner] interface.
|
||||||
func (interval *Interval) ScanInterval(v Interval) error {
|
func (interval *Interval) ScanInterval(v Interval) error {
|
||||||
*interval = v
|
*interval = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IntervalValue implements the [IntervalValuer] interface.
|
||||||
func (interval Interval) IntervalValue() (Interval, error) {
|
func (interval Interval) IntervalValue() (Interval, error) {
|
||||||
return interval, nil
|
return interval, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (interval *Interval) Scan(src any) error {
|
func (interval *Interval) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*interval = Interval{}
|
*interval = Interval{}
|
||||||
@ -57,7 +59,7 @@ func (interval *Interval) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (interval Interval) Value() (driver.Value, error) {
|
func (interval Interval) Value() (driver.Value, error) {
|
||||||
if !interval.Valid {
|
if !interval.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -157,7 +159,6 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (IntervalCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -95,7 +95,8 @@ func TestJSONBCodecCustomMarshal(t *testing.T) {
|
|||||||
Unmarshal: func(data []byte, v any) error {
|
Unmarshal: func(data []byte, v any) error {
|
||||||
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
return json.Unmarshal([]byte(`{"custom":"value"}`), v)
|
||||||
},
|
},
|
||||||
}})
|
},
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
|
pgxtest.RunValueRoundTripTests(context.Background(), t, connTestRunner, pgxtest.KnownOIDQueryExecModes, "jsonb", []pgxtest.ValueRoundTripTest{
|
||||||
|
@ -24,11 +24,13 @@ type Line struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanLine implements the [LineScanner] interface.
|
||||||
func (line *Line) ScanLine(v Line) error {
|
func (line *Line) ScanLine(v Line) error {
|
||||||
*line = v
|
*line = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LineValue implements the [LineValuer] interface.
|
||||||
func (line Line) LineValue() (Line, error) {
|
func (line Line) LineValue() (Line, error) {
|
||||||
return line, nil
|
return line, nil
|
||||||
}
|
}
|
||||||
@ -37,7 +39,7 @@ func (line *Line) Set(src any) error {
|
|||||||
return fmt.Errorf("cannot convert %v to Line", src)
|
return fmt.Errorf("cannot convert %v to Line", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (line *Line) Scan(src any) error {
|
func (line *Line) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*line = Line{}
|
*line = Line{}
|
||||||
@ -52,7 +54,7 @@ func (line *Line) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (line Line) Value() (driver.Value, error) {
|
func (line Line) Value() (driver.Value, error) {
|
||||||
if !line.Valid {
|
if !line.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -129,7 +131,6 @@ func (encodePlanLineCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (LineCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -24,16 +24,18 @@ type Lseg struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanLseg implements the [LsegScanner] interface.
|
||||||
func (lseg *Lseg) ScanLseg(v Lseg) error {
|
func (lseg *Lseg) ScanLseg(v Lseg) error {
|
||||||
*lseg = v
|
*lseg = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LsegValue implements the [LsegValuer] interface.
|
||||||
func (lseg Lseg) LsegValue() (Lseg, error) {
|
func (lseg Lseg) LsegValue() (Lseg, error) {
|
||||||
return lseg, nil
|
return lseg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (lseg *Lseg) Scan(src any) error {
|
func (lseg *Lseg) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*lseg = Lseg{}
|
*lseg = Lseg{}
|
||||||
@ -48,7 +50,7 @@ func (lseg *Lseg) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (lseg Lseg) Value() (driver.Value, error) {
|
func (lseg Lseg) Value() (driver.Value, error) {
|
||||||
if !lseg.Valid {
|
if !lseg.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -127,7 +129,6 @@ func (encodePlanLsegCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (LsegCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -374,7 +374,6 @@ parseValueLoop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
return elements, nil
|
return elements, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseRange(buf *bytes.Buffer) (string, error) {
|
func parseRange(buf *bytes.Buffer) (string, error) {
|
||||||
@ -403,8 +402,8 @@ func parseRange(buf *bytes.Buffer) (string, error) {
|
|||||||
|
|
||||||
// Multirange is a generic multirange type.
|
// Multirange is a generic multirange type.
|
||||||
//
|
//
|
||||||
// T should implement RangeValuer and *T should implement RangeScanner. However, there does not appear to be a way to
|
// T should implement [RangeValuer] and *T should implement [RangeScanner]. However, there does not appear to be a way to
|
||||||
// enforce the RangeScanner constraint.
|
// enforce the [RangeScanner] constraint.
|
||||||
type Multirange[T RangeValuer] []T
|
type Multirange[T RangeValuer] []T
|
||||||
|
|
||||||
func (r Multirange[T]) IsNull() bool {
|
func (r Multirange[T]) IsNull() bool {
|
||||||
|
@ -71,7 +71,6 @@ func TestMultirangeCodecDecodeValue(t *testing.T) {
|
|||||||
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
|
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
sql string
|
sql string
|
||||||
expected any
|
expected any
|
||||||
|
@ -27,16 +27,20 @@ const (
|
|||||||
pgNumericNegInfSign = 0xf000
|
pgNumericNegInfSign = 0xf000
|
||||||
)
|
)
|
||||||
|
|
||||||
var big0 *big.Int = big.NewInt(0)
|
var (
|
||||||
var big1 *big.Int = big.NewInt(1)
|
big0 *big.Int = big.NewInt(0)
|
||||||
var big10 *big.Int = big.NewInt(10)
|
big1 *big.Int = big.NewInt(1)
|
||||||
var big100 *big.Int = big.NewInt(100)
|
big10 *big.Int = big.NewInt(10)
|
||||||
var big1000 *big.Int = big.NewInt(1000)
|
big100 *big.Int = big.NewInt(100)
|
||||||
|
big1000 *big.Int = big.NewInt(1000)
|
||||||
|
)
|
||||||
|
|
||||||
var bigNBase *big.Int = big.NewInt(nbase)
|
var (
|
||||||
var bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
|
bigNBase *big.Int = big.NewInt(nbase)
|
||||||
var bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
|
bigNBaseX2 *big.Int = big.NewInt(nbase * nbase)
|
||||||
var bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
|
bigNBaseX3 *big.Int = big.NewInt(nbase * nbase * nbase)
|
||||||
|
bigNBaseX4 *big.Int = big.NewInt(nbase * nbase * nbase * nbase)
|
||||||
|
)
|
||||||
|
|
||||||
type NumericScanner interface {
|
type NumericScanner interface {
|
||||||
ScanNumeric(v Numeric) error
|
ScanNumeric(v Numeric) error
|
||||||
@ -54,15 +58,18 @@ type Numeric struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanNumeric implements the [NumericScanner] interface.
|
||||||
func (n *Numeric) ScanNumeric(v Numeric) error {
|
func (n *Numeric) ScanNumeric(v Numeric) error {
|
||||||
*n = v
|
*n = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NumericValue implements the [NumericValuer] interface.
|
||||||
func (n Numeric) NumericValue() (Numeric, error) {
|
func (n Numeric) NumericValue() (Numeric, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Float64Value implements the [Float64Valuer] interface.
|
||||||
func (n Numeric) Float64Value() (Float8, error) {
|
func (n Numeric) Float64Value() (Float8, error) {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
return Float8{}, nil
|
return Float8{}, nil
|
||||||
@ -92,6 +99,7 @@ func (n Numeric) Float64Value() (Float8, error) {
|
|||||||
return Float8{Float64: f, Valid: true}, nil
|
return Float8{Float64: f, Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanInt64 implements the [Int64Scanner] interface.
|
||||||
func (n *Numeric) ScanInt64(v Int8) error {
|
func (n *Numeric) ScanInt64(v Int8) error {
|
||||||
if !v.Valid {
|
if !v.Valid {
|
||||||
*n = Numeric{}
|
*n = Numeric{}
|
||||||
@ -102,6 +110,7 @@ func (n *Numeric) ScanInt64(v Int8) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Int64Value implements the [Int64Valuer] interface.
|
||||||
func (n Numeric) Int64Value() (Int8, error) {
|
func (n Numeric) Int64Value() (Int8, error) {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
return Int8{}, nil
|
return Int8{}, nil
|
||||||
@ -203,7 +212,7 @@ func nbaseDigitsToInt64(src []byte) (accum int64, bytesRead, digitsRead int) {
|
|||||||
return accum, rp, digits
|
return accum, rp, digits
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (n *Numeric) Scan(src any) error {
|
func (n *Numeric) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*n = Numeric{}
|
*n = Numeric{}
|
||||||
@ -218,7 +227,7 @@ func (n *Numeric) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (n Numeric) Value() (driver.Value, error) {
|
func (n Numeric) Value() (driver.Value, error) {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -231,6 +240,7 @@ func (n Numeric) Value() (driver.Value, error) {
|
|||||||
return string(buf), err
|
return string(buf), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (n Numeric) MarshalJSON() ([]byte, error) {
|
func (n Numeric) MarshalJSON() ([]byte, error) {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -243,6 +253,7 @@ func (n Numeric) MarshalJSON() ([]byte, error) {
|
|||||||
return n.numberTextBytes(), nil
|
return n.numberTextBytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (n *Numeric) UnmarshalJSON(src []byte) error {
|
func (n *Numeric) UnmarshalJSON(src []byte) error {
|
||||||
if bytes.Equal(src, []byte(`null`)) {
|
if bytes.Equal(src, []byte(`null`)) {
|
||||||
*n = Numeric{}
|
*n = Numeric{}
|
||||||
@ -553,7 +564,6 @@ func encodeNumericText(n Numeric, buf []byte) (newBuf []byte, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (NumericCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -198,7 +198,6 @@ func TestNumericMarshalJSON(t *testing.T) {
|
|||||||
skipCockroachDB(t, "server formats numeric text format differently")
|
skipCockroachDB(t, "server formats numeric text format differently")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
for i, tt := range []struct {
|
for i, tt := range []struct {
|
||||||
decString string
|
decString string
|
||||||
}{
|
}{
|
||||||
|
@ -25,16 +25,18 @@ type Path struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanPath implements the [PathScanner] interface.
|
||||||
func (path *Path) ScanPath(v Path) error {
|
func (path *Path) ScanPath(v Path) error {
|
||||||
*path = v
|
*path = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PathValue implements the [PathValuer] interface.
|
||||||
func (path Path) PathValue() (Path, error) {
|
func (path Path) PathValue() (Path, error) {
|
||||||
return path, nil
|
return path, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (path *Path) Scan(src any) error {
|
func (path *Path) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*path = Path{}
|
*path = Path{}
|
||||||
@ -49,7 +51,7 @@ func (path *Path) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (path Path) Value() (driver.Value, error) {
|
func (path Path) Value() (driver.Value, error) {
|
||||||
if !path.Valid {
|
if !path.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -154,7 +156,6 @@ func (encodePlanPathCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (PathCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -2010,7 +2010,7 @@ var valuerReflectType = reflect.TypeFor[driver.Valuer]()
|
|||||||
|
|
||||||
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
|
// isNilDriverValuer returns true if value is any type of nil unless it implements driver.Valuer. *T is not considered to implement
|
||||||
// driver.Valuer if it is only implemented by T.
|
// driver.Valuer if it is only implemented by T.
|
||||||
func isNilDriverValuer(value any) (isNil bool, callNilDriverValuer bool) {
|
func isNilDriverValuer(value any) (isNil, callNilDriverValuer bool) {
|
||||||
if value == nil {
|
if value == nil {
|
||||||
return true, false
|
return true, false
|
||||||
}
|
}
|
||||||
|
@ -34,17 +34,19 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Test for renamed types
|
// Test for renamed types
|
||||||
type _string string
|
type (
|
||||||
type _bool bool
|
_string string
|
||||||
type _uint8 uint8
|
_bool bool
|
||||||
type _int8 int8
|
_uint8 uint8
|
||||||
type _int16 int16
|
_int8 int8
|
||||||
type _int16Slice []int16
|
_int16 int16
|
||||||
type _int32Slice []int32
|
_int16Slice []int16
|
||||||
type _int64Slice []int64
|
_int32Slice []int32
|
||||||
type _float32Slice []float32
|
_int64Slice []int64
|
||||||
type _float64Slice []float64
|
_float32Slice []float32
|
||||||
type _byteSlice []byte
|
_float64Slice []float64
|
||||||
|
_byteSlice []byte
|
||||||
|
)
|
||||||
|
|
||||||
// unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type
|
// unregisteredOID represents an actual type that is not registered. Cannot use 0 because that represents that the type
|
||||||
// is not known (e.g. when using the simple protocol).
|
// is not known (e.g. when using the simple protocol).
|
||||||
@ -530,7 +532,8 @@ func TestMapEncodePlanCacheUUIDTypeConfusion(t *testing.T) {
|
|||||||
0, 0, 0, 16,
|
0, 0, 0, 16,
|
||||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
|
||||||
0, 0, 0, 16,
|
0, 0, 0, 16,
|
||||||
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}
|
15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
||||||
|
}
|
||||||
|
|
||||||
m := pgtype.NewMap()
|
m := pgtype.NewMap()
|
||||||
buf, err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode,
|
buf, err := m.Encode(pgtype.UUIDArrayOID, pgtype.BinaryFormatCode,
|
||||||
|
@ -30,11 +30,13 @@ type Point struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanPoint implements the [PointScanner] interface.
|
||||||
func (p *Point) ScanPoint(v Point) error {
|
func (p *Point) ScanPoint(v Point) error {
|
||||||
*p = v
|
*p = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PointValue implements the [PointValuer] interface.
|
||||||
func (p Point) PointValue() (Point, error) {
|
func (p Point) PointValue() (Point, error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
@ -68,7 +70,7 @@ func parsePoint(src []byte) (*Point, error) {
|
|||||||
return &Point{P: Vec2{x, y}, Valid: true}, nil
|
return &Point{P: Vec2{x, y}, Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Point) Scan(src any) error {
|
func (dst *Point) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Point{}
|
*dst = Point{}
|
||||||
@ -83,7 +85,7 @@ func (dst *Point) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Point) Value() (driver.Value, error) {
|
func (src Point) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -96,6 +98,7 @@ func (src Point) Value() (driver.Value, error) {
|
|||||||
return string(buf), err
|
return string(buf), err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Point) MarshalJSON() ([]byte, error) {
|
func (src Point) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -108,6 +111,7 @@ func (src Point) MarshalJSON() ([]byte, error) {
|
|||||||
return buff.Bytes(), nil
|
return buff.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Point) UnmarshalJSON(point []byte) error {
|
func (dst *Point) UnmarshalJSON(point []byte) error {
|
||||||
p, err := parsePoint(point)
|
p, err := parsePoint(point)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -178,7 +182,6 @@ func (encodePlanPointCodecText) Encode(value any, buf []byte) (newBuf []byte, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (PointCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -24,16 +24,18 @@ type Polygon struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanPolygon implements the [PolygonScanner] interface.
|
||||||
func (p *Polygon) ScanPolygon(v Polygon) error {
|
func (p *Polygon) ScanPolygon(v Polygon) error {
|
||||||
*p = v
|
*p = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PolygonValue implements the [PolygonValuer] interface.
|
||||||
func (p Polygon) PolygonValue() (Polygon, error) {
|
func (p Polygon) PolygonValue() (Polygon, error) {
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (p *Polygon) Scan(src any) error {
|
func (p *Polygon) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*p = Polygon{}
|
*p = Polygon{}
|
||||||
@ -48,7 +50,7 @@ func (p *Polygon) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (p Polygon) Value() (driver.Value, error) {
|
func (p Polygon) Value() (driver.Value, error) {
|
||||||
if !p.Valid {
|
if !p.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -139,7 +141,6 @@ func (encodePlanPolygonCodecText) Encode(value any, buf []byte) (newBuf []byte,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (PolygonCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -191,11 +191,13 @@ type untypedBinaryRange struct {
|
|||||||
// 18 = [ = 10010
|
// 18 = [ = 10010
|
||||||
// 24 = = 11000
|
// 24 = = 11000
|
||||||
|
|
||||||
const emptyMask = 1
|
const (
|
||||||
const lowerInclusiveMask = 2
|
emptyMask = 1
|
||||||
const upperInclusiveMask = 4
|
lowerInclusiveMask = 2
|
||||||
const lowerUnboundedMask = 8
|
upperInclusiveMask = 4
|
||||||
const upperUnboundedMask = 16
|
lowerUnboundedMask = 8
|
||||||
|
upperUnboundedMask = 16
|
||||||
|
)
|
||||||
|
|
||||||
func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) {
|
func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) {
|
||||||
ubr := &untypedBinaryRange{}
|
ubr := &untypedBinaryRange{}
|
||||||
@ -273,7 +275,6 @@ func parseUntypedBinaryRange(src []byte) (*untypedBinaryRange, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ubr, nil
|
return ubr, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Range is a generic range type.
|
// Range is a generic range type.
|
||||||
|
@ -75,7 +75,6 @@ func TestRangeCodecScanRangeTwiceWithUnbounded(t *testing.T) {
|
|||||||
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
|
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
var r pgtype.Range[pgtype.Int4]
|
var r pgtype.Range[pgtype.Int4]
|
||||||
|
|
||||||
err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r)
|
err := conn.QueryRow(context.Background(), `select '[1,5)'::int4range`).Scan(&r)
|
||||||
@ -129,7 +128,6 @@ func TestRangeCodecDecodeValue(t *testing.T) {
|
|||||||
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
|
skipCockroachDB(t, "Server does not support range types (see https://github.com/cockroachdb/cockroach/issues/27791)")
|
||||||
|
|
||||||
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, _ testing.TB, conn *pgx.Conn) {
|
||||||
|
|
||||||
for _, tt := range []struct {
|
for _, tt := range []struct {
|
||||||
sql string
|
sql string
|
||||||
expected any
|
expected any
|
||||||
|
@ -121,5 +121,4 @@ func (RecordCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (an
|
|||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unknown format code %d", format)
|
return nil, fmt.Errorf("unknown format code %d", format)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -19,16 +19,18 @@ type Text struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanText implements the [TextScanner] interface.
|
||||||
func (t *Text) ScanText(v Text) error {
|
func (t *Text) ScanText(v Text) error {
|
||||||
*t = v
|
*t = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TextValue implements the [TextValuer] interface.
|
||||||
func (t Text) TextValue() (Text, error) {
|
func (t Text) TextValue() (Text, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Text) Scan(src any) error {
|
func (dst *Text) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Text{}
|
*dst = Text{}
|
||||||
@ -47,7 +49,7 @@ func (dst *Text) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Text) Value() (driver.Value, error) {
|
func (src Text) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -55,6 +57,7 @@ func (src Text) Value() (driver.Value, error) {
|
|||||||
return src.String, nil
|
return src.String, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src Text) MarshalJSON() ([]byte, error) {
|
func (src Text) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -63,6 +66,7 @@ func (src Text) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(src.String)
|
return json.Marshal(src.String)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *Text) UnmarshalJSON(b []byte) error {
|
func (dst *Text) UnmarshalJSON(b []byte) error {
|
||||||
var s *string
|
var s *string
|
||||||
err := json.Unmarshal(b, &s)
|
err := json.Unmarshal(b, &s)
|
||||||
@ -146,7 +150,6 @@ func (encodePlanTextCodecTextValuer) Encode(value any, buf []byte) (newBuf []byt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (TextCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case TextFormatCode, BinaryFormatCode:
|
case TextFormatCode, BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -35,16 +35,18 @@ type TID struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanTID implements the [TIDScanner] interface.
|
||||||
func (b *TID) ScanTID(v TID) error {
|
func (b *TID) ScanTID(v TID) error {
|
||||||
*b = v
|
*b = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TIDValue implements the [TIDValuer] interface.
|
||||||
func (b TID) TIDValue() (TID, error) {
|
func (b TID) TIDValue() (TID, error) {
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *TID) Scan(src any) error {
|
func (dst *TID) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = TID{}
|
*dst = TID{}
|
||||||
@ -59,7 +61,7 @@ func (dst *TID) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src TID) Value() (driver.Value, error) {
|
func (src TID) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -131,7 +133,6 @@ func (encodePlanTIDCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (TIDCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -29,16 +29,18 @@ type Time struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanTime implements the [TimeScanner] interface.
|
||||||
func (t *Time) ScanTime(v Time) error {
|
func (t *Time) ScanTime(v Time) error {
|
||||||
*t = v
|
*t = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TimeValue implements the [TimeValuer] interface.
|
||||||
func (t Time) TimeValue() (Time, error) {
|
func (t Time) TimeValue() (Time, error) {
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (t *Time) Scan(src any) error {
|
func (t *Time) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*t = Time{}
|
*t = Time{}
|
||||||
@ -58,7 +60,7 @@ func (t *Time) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (t Time) Value() (driver.Value, error) {
|
func (t Time) Value() (driver.Value, error) {
|
||||||
if !t.Valid {
|
if !t.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -137,7 +139,6 @@ func (encodePlanTimeCodecText) Encode(value any, buf []byte) (newBuf []byte, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (TimeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -11,8 +11,10 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
|
|
||||||
const pgTimestampFormat = "2006-01-02 15:04:05.999999999"
|
const (
|
||||||
const jsonISO8601 = "2006-01-02T15:04:05.999999999"
|
pgTimestampFormat = "2006-01-02 15:04:05.999999999"
|
||||||
|
jsonISO8601 = "2006-01-02T15:04:05.999999999"
|
||||||
|
)
|
||||||
|
|
||||||
type TimestampScanner interface {
|
type TimestampScanner interface {
|
||||||
ScanTimestamp(v Timestamp) error
|
ScanTimestamp(v Timestamp) error
|
||||||
@ -29,16 +31,18 @@ type Timestamp struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanTimestamp implements the [TimestampScanner] interface.
|
||||||
func (ts *Timestamp) ScanTimestamp(v Timestamp) error {
|
func (ts *Timestamp) ScanTimestamp(v Timestamp) error {
|
||||||
*ts = v
|
*ts = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TimestampValue implements the [TimestampValuer] interface.
|
||||||
func (ts Timestamp) TimestampValue() (Timestamp, error) {
|
func (ts Timestamp) TimestampValue() (Timestamp, error) {
|
||||||
return ts, nil
|
return ts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (ts *Timestamp) Scan(src any) error {
|
func (ts *Timestamp) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*ts = Timestamp{}
|
*ts = Timestamp{}
|
||||||
@ -56,7 +60,7 @@ func (ts *Timestamp) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (ts Timestamp) Value() (driver.Value, error) {
|
func (ts Timestamp) Value() (driver.Value, error) {
|
||||||
if !ts.Valid {
|
if !ts.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -68,6 +72,7 @@ func (ts Timestamp) Value() (driver.Value, error) {
|
|||||||
return ts.Time, nil
|
return ts.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (ts Timestamp) MarshalJSON() ([]byte, error) {
|
func (ts Timestamp) MarshalJSON() ([]byte, error) {
|
||||||
if !ts.Valid {
|
if !ts.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -87,6 +92,7 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(s)
|
return json.Marshal(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (ts *Timestamp) UnmarshalJSON(b []byte) error {
|
func (ts *Timestamp) UnmarshalJSON(b []byte) error {
|
||||||
var s *string
|
var s *string
|
||||||
err := json.Unmarshal(b, &s)
|
err := json.Unmarshal(b, &s)
|
||||||
|
@ -102,7 +102,6 @@ func TestTimestampCodecDecodeTextInvalid(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestTimestampMarshalJSON(t *testing.T) {
|
func TestTimestampMarshalJSON(t *testing.T) {
|
||||||
|
|
||||||
tsStruct := struct {
|
tsStruct := struct {
|
||||||
TS pgtype.Timestamp `json:"ts"`
|
TS pgtype.Timestamp `json:"ts"`
|
||||||
}{}
|
}{}
|
||||||
|
@ -11,10 +11,12 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/internal/pgio"
|
"github.com/jackc/pgx/v5/internal/pgio"
|
||||||
)
|
)
|
||||||
|
|
||||||
const pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07"
|
const (
|
||||||
const pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00"
|
pgTimestamptzHourFormat = "2006-01-02 15:04:05.999999999Z07"
|
||||||
const pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00"
|
pgTimestamptzMinuteFormat = "2006-01-02 15:04:05.999999999Z07:00"
|
||||||
const microsecFromUnixEpochToY2K = 946684800 * 1000000
|
pgTimestamptzSecondFormat = "2006-01-02 15:04:05.999999999Z07:00:00"
|
||||||
|
microsecFromUnixEpochToY2K = 946684800 * 1000000
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
negativeInfinityMicrosecondOffset = -9223372036854775808
|
negativeInfinityMicrosecondOffset = -9223372036854775808
|
||||||
@ -36,16 +38,18 @@ type Timestamptz struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanTimestamptz implements the [TimestamptzScanner] interface.
|
||||||
func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error {
|
func (tstz *Timestamptz) ScanTimestamptz(v Timestamptz) error {
|
||||||
*tstz = v
|
*tstz = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TimestamptzValue implements the [TimestamptzValuer] interface.
|
||||||
func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) {
|
func (tstz Timestamptz) TimestamptzValue() (Timestamptz, error) {
|
||||||
return tstz, nil
|
return tstz, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (tstz *Timestamptz) Scan(src any) error {
|
func (tstz *Timestamptz) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*tstz = Timestamptz{}
|
*tstz = Timestamptz{}
|
||||||
@ -63,7 +67,7 @@ func (tstz *Timestamptz) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (tstz Timestamptz) Value() (driver.Value, error) {
|
func (tstz Timestamptz) Value() (driver.Value, error) {
|
||||||
if !tstz.Valid {
|
if !tstz.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -75,6 +79,7 @@ func (tstz Timestamptz) Value() (driver.Value, error) {
|
|||||||
return tstz.Time, nil
|
return tstz.Time, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (tstz Timestamptz) MarshalJSON() ([]byte, error) {
|
func (tstz Timestamptz) MarshalJSON() ([]byte, error) {
|
||||||
if !tstz.Valid {
|
if !tstz.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -94,6 +99,7 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) {
|
|||||||
return json.Marshal(s)
|
return json.Marshal(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
|
func (tstz *Timestamptz) UnmarshalJSON(b []byte) error {
|
||||||
var s *string
|
var s *string
|
||||||
err := json.Unmarshal(b, &s)
|
err := json.Unmarshal(b, &s)
|
||||||
@ -225,7 +231,6 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (c *TimestamptzCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -3,6 +3,7 @@ package pgtype
|
|||||||
import (
|
import (
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -24,16 +25,18 @@ type Uint32 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanUint32 implements the [Uint32Scanner] interface.
|
||||||
func (n *Uint32) ScanUint32(v Uint32) error {
|
func (n *Uint32) ScanUint32(v Uint32) error {
|
||||||
*n = v
|
*n = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Uint32Value implements the [Uint32Valuer] interface.
|
||||||
func (n Uint32) Uint32Value() (Uint32, error) {
|
func (n Uint32) Uint32Value() (Uint32, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Uint32) Scan(src any) error {
|
func (dst *Uint32) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Uint32{}
|
*dst = Uint32{}
|
||||||
@ -67,7 +70,7 @@ func (dst *Uint32) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Uint32) Value() (driver.Value, error) {
|
func (src Uint32) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -75,6 +78,31 @@ func (src Uint32) Value() (driver.Value, error) {
|
|||||||
return int64(src.Uint32), nil
|
return int64(src.Uint32), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
|
func (src Uint32) MarshalJSON() ([]byte, error) {
|
||||||
|
if !src.Valid {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return json.Marshal(src.Uint32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
|
func (dst *Uint32) UnmarshalJSON(b []byte) error {
|
||||||
|
var n *uint32
|
||||||
|
err := json.Unmarshal(b, &n)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if n == nil {
|
||||||
|
*dst = Uint32{}
|
||||||
|
} else {
|
||||||
|
*dst = Uint32{Uint32: *n, Valid: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type Uint32Codec struct{}
|
type Uint32Codec struct{}
|
||||||
|
|
||||||
func (Uint32Codec) FormatSupported(format int16) bool {
|
func (Uint32Codec) FormatSupported(format int16) bool {
|
||||||
@ -197,7 +225,6 @@ func (encodePlanUint32CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (Uint32Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -24,16 +24,18 @@ type Uint64 struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanUint64 implements the [Uint64Scanner] interface.
|
||||||
func (n *Uint64) ScanUint64(v Uint64) error {
|
func (n *Uint64) ScanUint64(v Uint64) error {
|
||||||
*n = v
|
*n = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Uint64Value implements the [Uint64Valuer] interface.
|
||||||
func (n Uint64) Uint64Value() (Uint64, error) {
|
func (n Uint64) Uint64Value() (Uint64, error) {
|
||||||
return n, nil
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Uint64) Scan(src any) error {
|
func (dst *Uint64) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = Uint64{}
|
*dst = Uint64{}
|
||||||
@ -63,7 +65,7 @@ func (dst *Uint64) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Uint64) Value() (driver.Value, error) {
|
func (src Uint64) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -194,7 +196,6 @@ func (encodePlanUint64CodecTextInt64Valuer) Encode(value any, buf []byte) (newBu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
func (Uint64Codec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
|
||||||
|
|
||||||
switch format {
|
switch format {
|
||||||
case BinaryFormatCode:
|
case BinaryFormatCode:
|
||||||
switch target.(type) {
|
switch target.(type) {
|
||||||
|
@ -20,11 +20,13 @@ type UUID struct {
|
|||||||
Valid bool
|
Valid bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ScanUUID implements the [UUIDScanner] interface.
|
||||||
func (b *UUID) ScanUUID(v UUID) error {
|
func (b *UUID) ScanUUID(v UUID) error {
|
||||||
*b = v
|
*b = v
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UUIDValue implements the [UUIDValuer] interface.
|
||||||
func (b UUID) UUIDValue() (UUID, error) {
|
func (b UUID) UUIDValue() (UUID, error) {
|
||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
@ -67,7 +69,7 @@ func encodeUUID(src [16]byte) string {
|
|||||||
return string(buf[:])
|
return string(buf[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *UUID) Scan(src any) error {
|
func (dst *UUID) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = UUID{}
|
*dst = UUID{}
|
||||||
@ -87,7 +89,7 @@ func (dst *UUID) Scan(src any) error {
|
|||||||
return fmt.Errorf("cannot scan %T", src)
|
return fmt.Errorf("cannot scan %T", src)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src UUID) Value() (driver.Value, error) {
|
func (src UUID) Value() (driver.Value, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -104,6 +106,7 @@ func (src UUID) String() string {
|
|||||||
return encodeUUID(src.Bytes)
|
return encodeUUID(src.Bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the [encoding/json.Marshaler] interface.
|
||||||
func (src UUID) MarshalJSON() ([]byte, error) {
|
func (src UUID) MarshalJSON() ([]byte, error) {
|
||||||
if !src.Valid {
|
if !src.Valid {
|
||||||
return []byte("null"), nil
|
return []byte("null"), nil
|
||||||
@ -116,6 +119,7 @@ func (src UUID) MarshalJSON() ([]byte, error) {
|
|||||||
return buff.Bytes(), nil
|
return buff.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface.
|
||||||
func (dst *UUID) UnmarshalJSON(src []byte) error {
|
func (dst *UUID) UnmarshalJSON(src []byte) error {
|
||||||
if bytes.Equal(src, []byte("null")) {
|
if bytes.Equal(src, []byte("null")) {
|
||||||
*dst = UUID{}
|
*dst = UUID{}
|
||||||
|
@ -8,9 +8,10 @@ import (
|
|||||||
|
|
||||||
type Float8 float64
|
type Float8 float64
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Float8) SkipUnderlyingTypePlan() {}
|
func (Float8) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanFloat64 implements the Float64Scanner interface.
|
// ScanFloat64 implements the [pgtype.Float64Scanner] interface.
|
||||||
func (f *Float8) ScanFloat64(n pgtype.Float8) error {
|
func (f *Float8) ScanFloat64(n pgtype.Float8) error {
|
||||||
if !n.Valid {
|
if !n.Valid {
|
||||||
*f = 0
|
*f = 0
|
||||||
@ -22,6 +23,7 @@ func (f *Float8) ScanFloat64(n pgtype.Float8) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Float64Value implements the [pgtype.Float64Valuer] interface.
|
||||||
func (f Float8) Float64Value() (pgtype.Float8, error) {
|
func (f Float8) Float64Value() (pgtype.Float8, error) {
|
||||||
if f == 0 {
|
if f == 0 {
|
||||||
return pgtype.Float8{}, nil
|
return pgtype.Float8{}, nil
|
||||||
@ -29,7 +31,7 @@ func (f Float8) Float64Value() (pgtype.Float8, error) {
|
|||||||
return pgtype.Float8{Float64: float64(f), Valid: true}, nil
|
return pgtype.Float8{Float64: float64(f), Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (f *Float8) Scan(src any) error {
|
func (f *Float8) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*f = 0
|
*f = 0
|
||||||
@ -47,7 +49,7 @@ func (f *Float8) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (f Float8) Value() (driver.Value, error) {
|
func (f Float8) Value() (driver.Value, error) {
|
||||||
if f == 0 {
|
if f == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -12,27 +12,36 @@ import (
|
|||||||
|
|
||||||
type Int2 int16
|
type Int2 int16
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Int2) SkipUnderlyingTypePlan() {}
|
func (Int2) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [pgtype.Int64Scanner] interface.
|
||||||
func (dst *Int2) ScanInt64(n int64, valid bool) error {
|
func (dst *Int2) ScanInt64(n pgtype.Int8) error {
|
||||||
if !valid {
|
if !n.Valid {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n < math.MinInt16 {
|
if n.Int64 < math.MinInt16 {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int2", n)
|
return fmt.Errorf("%d is less than minimum value for Int2", n.Int64)
|
||||||
}
|
}
|
||||||
if n > math.MaxInt16 {
|
if n.Int64 > math.MaxInt16 {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int2", n)
|
return fmt.Errorf("%d is greater than maximum value for Int2", n.Int64)
|
||||||
}
|
}
|
||||||
*dst = Int2(n)
|
*dst = Int2(n.Int64)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Int64Value implements the [pgtype.Int64Valuer] interface.
|
||||||
|
func (src Int2) Int64Value() (pgtype.Int8, error) {
|
||||||
|
if src == 0 {
|
||||||
|
return pgtype.Int8{}, nil
|
||||||
|
}
|
||||||
|
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int2) Scan(src any) error {
|
func (dst *Int2) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
@ -50,7 +59,7 @@ func (dst *Int2) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int2) Value() (driver.Value, error) {
|
func (src Int2) Value() (driver.Value, error) {
|
||||||
if src == 0 {
|
if src == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -60,27 +69,36 @@ func (src Int2) Value() (driver.Value, error) {
|
|||||||
|
|
||||||
type Int4 int32
|
type Int4 int32
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Int4) SkipUnderlyingTypePlan() {}
|
func (Int4) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [pgtype.Int64Scanner] interface.
|
||||||
func (dst *Int4) ScanInt64(n int64, valid bool) error {
|
func (dst *Int4) ScanInt64(n pgtype.Int8) error {
|
||||||
if !valid {
|
if !n.Valid {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n < math.MinInt32 {
|
if n.Int64 < math.MinInt32 {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int4", n)
|
return fmt.Errorf("%d is less than minimum value for Int4", n.Int64)
|
||||||
}
|
}
|
||||||
if n > math.MaxInt32 {
|
if n.Int64 > math.MaxInt32 {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int4", n)
|
return fmt.Errorf("%d is greater than maximum value for Int4", n.Int64)
|
||||||
}
|
}
|
||||||
*dst = Int4(n)
|
*dst = Int4(n.Int64)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Int64Value implements the [pgtype.Int64Valuer] interface.
|
||||||
|
func (src Int4) Int64Value() (pgtype.Int8, error) {
|
||||||
|
if src == 0 {
|
||||||
|
return pgtype.Int8{}, nil
|
||||||
|
}
|
||||||
|
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int4) Scan(src any) error {
|
func (dst *Int4) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
@ -98,7 +116,7 @@ func (dst *Int4) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int4) Value() (driver.Value, error) {
|
func (src Int4) Value() (driver.Value, error) {
|
||||||
if src == 0 {
|
if src == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
@ -108,27 +126,36 @@ func (src Int4) Value() (driver.Value, error) {
|
|||||||
|
|
||||||
type Int8 int64
|
type Int8 int64
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Int8) SkipUnderlyingTypePlan() {}
|
func (Int8) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [pgtype.Int64Scanner] interface.
|
||||||
func (dst *Int8) ScanInt64(n int64, valid bool) error {
|
func (dst *Int8) ScanInt64(n pgtype.Int8) error {
|
||||||
if !valid {
|
if !n.Valid {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n < math.MinInt64 {
|
if n.Int64 < math.MinInt64 {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int8", n)
|
return fmt.Errorf("%d is less than minimum value for Int8", n.Int64)
|
||||||
}
|
}
|
||||||
if n > math.MaxInt64 {
|
if n.Int64 > math.MaxInt64 {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int8", n)
|
return fmt.Errorf("%d is greater than maximum value for Int8", n.Int64)
|
||||||
}
|
}
|
||||||
*dst = Int8(n)
|
*dst = Int8(n.Int64)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Int64Value implements the [pgtype.Int64Valuer] interface.
|
||||||
|
func (src Int8) Int64Value() (pgtype.Int8, error) {
|
||||||
|
if src == 0 {
|
||||||
|
return pgtype.Int8{}, nil
|
||||||
|
}
|
||||||
|
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int8) Scan(src any) error {
|
func (dst *Int8) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
@ -146,7 +173,7 @@ func (dst *Int8) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int8) Value() (driver.Value, error) {
|
func (src Int8) Value() (driver.Value, error) {
|
||||||
if src == 0 {
|
if src == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -12,27 +12,36 @@ import (
|
|||||||
<% pg_bit_size = pg_byte_size * 8 %>
|
<% pg_bit_size = pg_byte_size * 8 %>
|
||||||
type Int<%= pg_byte_size %> int<%= pg_bit_size %>
|
type Int<%= pg_byte_size %> int<%= pg_bit_size %>
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {}
|
func (Int<%= pg_byte_size %>) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanInt64 implements the Int64Scanner interface.
|
// ScanInt64 implements the [pgtype.Int64Scanner] interface.
|
||||||
func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error {
|
func (dst *Int<%= pg_byte_size %>) ScanInt64(n pgtype.Int8) error {
|
||||||
if !valid {
|
if !n.Valid {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if n < math.MinInt<%= pg_bit_size %> {
|
if n.Int64 < math.MinInt<%= pg_bit_size %> {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n)
|
return fmt.Errorf("%d is less than minimum value for Int<%= pg_byte_size %>", n.Int64)
|
||||||
}
|
}
|
||||||
if n > math.MaxInt<%= pg_bit_size %> {
|
if n.Int64 > math.MaxInt<%= pg_bit_size %> {
|
||||||
return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n)
|
return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int64)
|
||||||
}
|
}
|
||||||
*dst = Int<%= pg_byte_size %>(n)
|
*dst = Int<%= pg_byte_size %>(n.Int64)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Int64Value implements the [pgtype.Int64Valuer] interface.
|
||||||
|
func (src Int<%= pg_byte_size %>) Int64Value() (pgtype.Int8, error) {
|
||||||
|
if src == 0 {
|
||||||
|
return pgtype.Int8{}, nil
|
||||||
|
}
|
||||||
|
return pgtype.Int8{Int64: int64(src), Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
|
func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = 0
|
*dst = 0
|
||||||
@ -50,7 +59,7 @@ func (dst *Int<%= pg_byte_size %>) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
|
func (src Int<%= pg_byte_size %>) Value() (driver.Value, error) {
|
||||||
if src == 0 {
|
if src == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -8,9 +8,10 @@ import (
|
|||||||
|
|
||||||
type Text string
|
type Text string
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Text) SkipUnderlyingTypePlan() {}
|
func (Text) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanText implements the TextScanner interface.
|
// ScanText implements the [pgtype.TextScanner] interface.
|
||||||
func (dst *Text) ScanText(v pgtype.Text) error {
|
func (dst *Text) ScanText(v pgtype.Text) error {
|
||||||
if !v.Valid {
|
if !v.Valid {
|
||||||
*dst = ""
|
*dst = ""
|
||||||
@ -22,7 +23,7 @@ func (dst *Text) ScanText(v pgtype.Text) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (dst *Text) Scan(src any) error {
|
func (dst *Text) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*dst = ""
|
*dst = ""
|
||||||
@ -40,7 +41,7 @@ func (dst *Text) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (src Text) Value() (driver.Value, error) {
|
func (src Text) Value() (driver.Value, error) {
|
||||||
if src == "" {
|
if src == "" {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -10,8 +10,10 @@ import (
|
|||||||
|
|
||||||
type Timestamp time.Time
|
type Timestamp time.Time
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Timestamp) SkipUnderlyingTypePlan() {}
|
func (Timestamp) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
|
// ScanTimestamp implements the [pgtype.TimestampScanner] interface.
|
||||||
func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error {
|
func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error {
|
||||||
if !v.Valid {
|
if !v.Valid {
|
||||||
*ts = Timestamp{}
|
*ts = Timestamp{}
|
||||||
@ -31,6 +33,7 @@ func (ts *Timestamp) ScanTimestamp(v pgtype.Timestamp) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TimestampValue implements the [pgtype.TimestampValuer] interface.
|
||||||
func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) {
|
func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) {
|
||||||
if time.Time(ts).IsZero() {
|
if time.Time(ts).IsZero() {
|
||||||
return pgtype.Timestamp{}, nil
|
return pgtype.Timestamp{}, nil
|
||||||
@ -39,7 +42,7 @@ func (ts Timestamp) TimestampValue() (pgtype.Timestamp, error) {
|
|||||||
return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil
|
return pgtype.Timestamp{Time: time.Time(ts), Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (ts *Timestamp) Scan(src any) error {
|
func (ts *Timestamp) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*ts = Timestamp{}
|
*ts = Timestamp{}
|
||||||
@ -57,7 +60,7 @@ func (ts *Timestamp) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (ts Timestamp) Value() (driver.Value, error) {
|
func (ts Timestamp) Value() (driver.Value, error) {
|
||||||
if time.Time(ts).IsZero() {
|
if time.Time(ts).IsZero() {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -10,8 +10,10 @@ import (
|
|||||||
|
|
||||||
type Timestamptz time.Time
|
type Timestamptz time.Time
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (Timestamptz) SkipUnderlyingTypePlan() {}
|
func (Timestamptz) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
|
// ScanTimestamptz implements the [pgtype.TimestamptzScanner] interface.
|
||||||
func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error {
|
func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error {
|
||||||
if !v.Valid {
|
if !v.Valid {
|
||||||
*ts = Timestamptz{}
|
*ts = Timestamptz{}
|
||||||
@ -31,6 +33,7 @@ func (ts *Timestamptz) ScanTimestamptz(v pgtype.Timestamptz) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TimestamptzValue implements the [pgtype.TimestamptzValuer] interface.
|
||||||
func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) {
|
func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) {
|
||||||
if time.Time(ts).IsZero() {
|
if time.Time(ts).IsZero() {
|
||||||
return pgtype.Timestamptz{}, nil
|
return pgtype.Timestamptz{}, nil
|
||||||
@ -39,7 +42,7 @@ func (ts Timestamptz) TimestamptzValue() (pgtype.Timestamptz, error) {
|
|||||||
return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil
|
return pgtype.Timestamptz{Time: time.Time(ts), Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (ts *Timestamptz) Scan(src any) error {
|
func (ts *Timestamptz) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*ts = Timestamptz{}
|
*ts = Timestamptz{}
|
||||||
@ -57,7 +60,7 @@ func (ts *Timestamptz) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (ts Timestamptz) Value() (driver.Value, error) {
|
func (ts Timestamptz) Value() (driver.Value, error) {
|
||||||
if time.Time(ts).IsZero() {
|
if time.Time(ts).IsZero() {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -8,9 +8,10 @@ import (
|
|||||||
|
|
||||||
type UUID [16]byte
|
type UUID [16]byte
|
||||||
|
|
||||||
|
// SkipUnderlyingTypePlan implements the [pgtype.SkipUnderlyingTypePlanner] interface.
|
||||||
func (UUID) SkipUnderlyingTypePlan() {}
|
func (UUID) SkipUnderlyingTypePlan() {}
|
||||||
|
|
||||||
// ScanUUID implements the UUIDScanner interface.
|
// ScanUUID implements the [pgtype.UUIDScanner] interface.
|
||||||
func (u *UUID) ScanUUID(v pgtype.UUID) error {
|
func (u *UUID) ScanUUID(v pgtype.UUID) error {
|
||||||
if !v.Valid {
|
if !v.Valid {
|
||||||
*u = UUID{}
|
*u = UUID{}
|
||||||
@ -22,6 +23,7 @@ func (u *UUID) ScanUUID(v pgtype.UUID) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UUIDValue implements the [pgtype.UUIDValuer] interface.
|
||||||
func (u UUID) UUIDValue() (pgtype.UUID, error) {
|
func (u UUID) UUIDValue() (pgtype.UUID, error) {
|
||||||
if u == (UUID{}) {
|
if u == (UUID{}) {
|
||||||
return pgtype.UUID{}, nil
|
return pgtype.UUID{}, nil
|
||||||
@ -29,7 +31,7 @@ func (u UUID) UUIDValue() (pgtype.UUID, error) {
|
|||||||
return pgtype.UUID{Bytes: u, Valid: true}, nil
|
return pgtype.UUID{Bytes: u, Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scan implements the database/sql Scanner interface.
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
func (u *UUID) Scan(src any) error {
|
func (u *UUID) Scan(src any) error {
|
||||||
if src == nil {
|
if src == nil {
|
||||||
*u = UUID{}
|
*u = UUID{}
|
||||||
@ -47,7 +49,7 @@ func (u *UUID) Scan(src any) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Value implements the database/sql/driver Valuer interface.
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
func (u UUID) Value() (driver.Value, error) {
|
func (u UUID) Value() (driver.Value, error) {
|
||||||
if u == (UUID{}) {
|
if u == (UUID{}) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
|
@ -97,7 +97,8 @@ func testCopyFrom(t *testing.T, ctx context.Context, db interface {
|
|||||||
execer
|
execer
|
||||||
queryer
|
queryer
|
||||||
copyFromer
|
copyFromer
|
||||||
}) {
|
},
|
||||||
|
) {
|
||||||
_, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
|
_, err := db.Exec(ctx, `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@ -141,6 +142,7 @@ func assertConfigsEqual(t *testing.T, expected, actual *pgxpool.Config, testName
|
|||||||
// Can't test function equality, so just test that they are set or not.
|
// Can't test function equality, so just test that they are set or not.
|
||||||
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
||||||
assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
|
assert.Equalf(t, expected.BeforeAcquire == nil, actual.BeforeAcquire == nil, "%s - BeforeAcquire", testName)
|
||||||
|
assert.Equalf(t, expected.PrepareConn == nil, actual.PrepareConn == nil, "%s - PrepareConn", testName)
|
||||||
assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
|
assert.Equalf(t, expected.AfterRelease == nil, actual.AfterRelease == nil, "%s - AfterRelease", testName)
|
||||||
|
|
||||||
assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)
|
assert.Equalf(t, expected.MaxConnLifetime, actual.MaxConnLifetime, "%s - MaxConnLifetime", testName)
|
||||||
|
115
pgxpool/pool.go
115
pgxpool/pool.go
@ -2,7 +2,7 @@ package pgxpool
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"errors"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -15,12 +15,14 @@ import (
|
|||||||
"github.com/jackc/puddle/v2"
|
"github.com/jackc/puddle/v2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultMaxConns = int32(4)
|
var (
|
||||||
var defaultMinConns = int32(0)
|
defaultMaxConns = int32(4)
|
||||||
var defaultMinIdleConns = int32(0)
|
defaultMinConns = int32(0)
|
||||||
var defaultMaxConnLifetime = time.Hour
|
defaultMinIdleConns = int32(0)
|
||||||
var defaultMaxConnIdleTime = time.Minute * 30
|
defaultMaxConnLifetime = time.Hour
|
||||||
var defaultHealthCheckPeriod = time.Minute
|
defaultMaxConnIdleTime = time.Minute * 30
|
||||||
|
defaultHealthCheckPeriod = time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
type connResource struct {
|
type connResource struct {
|
||||||
conn *pgx.Conn
|
conn *pgx.Conn
|
||||||
@ -84,9 +86,10 @@ type Pool struct {
|
|||||||
config *Config
|
config *Config
|
||||||
beforeConnect func(context.Context, *pgx.ConnConfig) error
|
beforeConnect func(context.Context, *pgx.ConnConfig) error
|
||||||
afterConnect func(context.Context, *pgx.Conn) error
|
afterConnect func(context.Context, *pgx.Conn) error
|
||||||
beforeAcquire func(context.Context, *pgx.Conn) bool
|
prepareConn func(context.Context, *pgx.Conn) (bool, error)
|
||||||
afterRelease func(*pgx.Conn) bool
|
afterRelease func(*pgx.Conn) bool
|
||||||
beforeClose func(*pgx.Conn)
|
beforeClose func(*pgx.Conn)
|
||||||
|
shouldPing func(context.Context, ShouldPingParams) bool
|
||||||
minConns int32
|
minConns int32
|
||||||
minIdleConns int32
|
minIdleConns int32
|
||||||
maxConns int32
|
maxConns int32
|
||||||
@ -104,6 +107,12 @@ type Pool struct {
|
|||||||
closeChan chan struct{}
|
closeChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShouldPingParams are the parameters passed to ShouldPing.
|
||||||
|
type ShouldPingParams struct {
|
||||||
|
Conn *pgx.Conn
|
||||||
|
IdleDuration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
|
// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
|
||||||
// modified.
|
// modified.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@ -119,8 +128,23 @@ type Config struct {
|
|||||||
// BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the
|
// BeforeAcquire is called before a connection is acquired from the pool. It must return true to allow the
|
||||||
// acquisition or false to indicate that the connection should be destroyed and a different connection should be
|
// acquisition or false to indicate that the connection should be destroyed and a different connection should be
|
||||||
// acquired.
|
// acquired.
|
||||||
|
//
|
||||||
|
// Deprecated: Use PrepareConn instead. If both PrepareConn and BeforeAcquire are set, PrepareConn will take
|
||||||
|
// precedence, ignoring BeforeAcquire.
|
||||||
BeforeAcquire func(context.Context, *pgx.Conn) bool
|
BeforeAcquire func(context.Context, *pgx.Conn) bool
|
||||||
|
|
||||||
|
// PrepareConn is called before a connection is acquired from the pool. If this function returns true, the connection
|
||||||
|
// is considered valid, otherwise the connection is destroyed. If the function returns a non-nil error, the instigating
|
||||||
|
// query will fail with the returned error.
|
||||||
|
//
|
||||||
|
// Specifically, this means that:
|
||||||
|
//
|
||||||
|
// - If it returns true and a nil error, the query proceeds as normal.
|
||||||
|
// - If it returns true and an error, the connection will be returned to the pool, and the instigating query will fail with the returned error.
|
||||||
|
// - If it returns false, and an error, the connection will be destroyed, and the query will fail with the returned error.
|
||||||
|
// - If it returns false and a nil error, the connection will be destroyed, and the instigating query will be retried on a new connection.
|
||||||
|
PrepareConn func(context.Context, *pgx.Conn) (bool, error)
|
||||||
|
|
||||||
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
|
// AfterRelease is called after a connection is released, but before it is returned to the pool. It must return true to
|
||||||
// return the connection to the pool or false to destroy the connection.
|
// return the connection to the pool or false to destroy the connection.
|
||||||
AfterRelease func(*pgx.Conn) bool
|
AfterRelease func(*pgx.Conn) bool
|
||||||
@ -128,6 +152,10 @@ type Config struct {
|
|||||||
// BeforeClose is called right before a connection is closed and removed from the pool.
|
// BeforeClose is called right before a connection is closed and removed from the pool.
|
||||||
BeforeClose func(*pgx.Conn)
|
BeforeClose func(*pgx.Conn)
|
||||||
|
|
||||||
|
// ShouldPing is called after a connection is acquired from the pool. If it returns true, the connection is pinged to check for liveness.
|
||||||
|
// If this func is not set, the default behavior is to ping connections that have been idle for at least 1 second.
|
||||||
|
ShouldPing func(context.Context, ShouldPingParams) bool
|
||||||
|
|
||||||
// MaxConnLifetime is the duration since creation after which a connection will be automatically closed.
|
// MaxConnLifetime is the duration since creation after which a connection will be automatically closed.
|
||||||
MaxConnLifetime time.Duration
|
MaxConnLifetime time.Duration
|
||||||
|
|
||||||
@ -190,11 +218,18 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
|
|||||||
panic("config must be created by ParseConfig")
|
panic("config must be created by ParseConfig")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
prepareConn := config.PrepareConn
|
||||||
|
if prepareConn == nil && config.BeforeAcquire != nil {
|
||||||
|
prepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) {
|
||||||
|
return config.BeforeAcquire(ctx, conn), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
p := &Pool{
|
p := &Pool{
|
||||||
config: config,
|
config: config,
|
||||||
beforeConnect: config.BeforeConnect,
|
beforeConnect: config.BeforeConnect,
|
||||||
afterConnect: config.AfterConnect,
|
afterConnect: config.AfterConnect,
|
||||||
beforeAcquire: config.BeforeAcquire,
|
prepareConn: prepareConn,
|
||||||
afterRelease: config.AfterRelease,
|
afterRelease: config.AfterRelease,
|
||||||
beforeClose: config.BeforeClose,
|
beforeClose: config.BeforeClose,
|
||||||
minConns: config.MinConns,
|
minConns: config.MinConns,
|
||||||
@ -216,6 +251,14 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
|
|||||||
p.releaseTracer = t
|
p.releaseTracer = t
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.ShouldPing != nil {
|
||||||
|
p.shouldPing = config.ShouldPing
|
||||||
|
} else {
|
||||||
|
p.shouldPing = func(ctx context.Context, params ShouldPingParams) bool {
|
||||||
|
return params.IdleDuration > time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
p.p, err = puddle.NewPool(
|
p.p, err = puddle.NewPool(
|
||||||
&puddle.Config[*connResource]{
|
&puddle.Config[*connResource]{
|
||||||
@ -321,10 +364,10 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_max_conns")
|
delete(connConfig.Config.RuntimeParams, "pool_max_conns")
|
||||||
n, err := strconv.ParseInt(s, 10, 32)
|
n, err := strconv.ParseInt(s, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot parse pool_max_conns: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conns", err)
|
||||||
}
|
}
|
||||||
if n < 1 {
|
if n < 1 {
|
||||||
return nil, fmt.Errorf("pool_max_conns too small: %d", n)
|
return nil, pgconn.NewParseConfigError(connString, "pool_max_conns too small", err)
|
||||||
}
|
}
|
||||||
config.MaxConns = int32(n)
|
config.MaxConns = int32(n)
|
||||||
} else {
|
} else {
|
||||||
@ -338,7 +381,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_min_conns")
|
delete(connConfig.Config.RuntimeParams, "pool_min_conns")
|
||||||
n, err := strconv.ParseInt(s, 10, 32)
|
n, err := strconv.ParseInt(s, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot parse pool_min_conns: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_conns", err)
|
||||||
}
|
}
|
||||||
config.MinConns = int32(n)
|
config.MinConns = int32(n)
|
||||||
} else {
|
} else {
|
||||||
@ -349,7 +392,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns")
|
delete(connConfig.Config.RuntimeParams, "pool_min_idle_conns")
|
||||||
n, err := strconv.ParseInt(s, 10, 32)
|
n, err := strconv.ParseInt(s, 10, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot parse pool_min_idle_conns: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_min_idle_conns", err)
|
||||||
}
|
}
|
||||||
config.MinIdleConns = int32(n)
|
config.MinIdleConns = int32(n)
|
||||||
} else {
|
} else {
|
||||||
@ -360,7 +403,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
|
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime")
|
||||||
d, err := time.ParseDuration(s)
|
d, err := time.ParseDuration(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid pool_max_conn_lifetime: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime", err)
|
||||||
}
|
}
|
||||||
config.MaxConnLifetime = d
|
config.MaxConnLifetime = d
|
||||||
} else {
|
} else {
|
||||||
@ -371,7 +414,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
|
delete(connConfig.Config.RuntimeParams, "pool_max_conn_idle_time")
|
||||||
d, err := time.ParseDuration(s)
|
d, err := time.ParseDuration(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid pool_max_conn_idle_time: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_idle_time", err)
|
||||||
}
|
}
|
||||||
config.MaxConnIdleTime = d
|
config.MaxConnIdleTime = d
|
||||||
} else {
|
} else {
|
||||||
@ -382,7 +425,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
|
delete(connConfig.Config.RuntimeParams, "pool_health_check_period")
|
||||||
d, err := time.ParseDuration(s)
|
d, err := time.ParseDuration(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid pool_health_check_period: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_health_check_period", err)
|
||||||
}
|
}
|
||||||
config.HealthCheckPeriod = d
|
config.HealthCheckPeriod = d
|
||||||
} else {
|
} else {
|
||||||
@ -393,7 +436,7 @@ func ParseConfig(connString string) (*Config, error) {
|
|||||||
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter")
|
delete(connConfig.Config.RuntimeParams, "pool_max_conn_lifetime_jitter")
|
||||||
d, err := time.ParseDuration(s)
|
d, err := time.ParseDuration(s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid pool_max_conn_lifetime_jitter: %w", err)
|
return nil, pgconn.NewParseConfigError(connString, "cannot parse pool_max_conn_lifetime_jitter", err)
|
||||||
}
|
}
|
||||||
config.MaxConnLifetimeJitter = d
|
config.MaxConnLifetimeJitter = d
|
||||||
}
|
}
|
||||||
@ -545,7 +588,10 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
for {
|
// Try to acquire from the connection pool up to maxConns + 1 times, so that
|
||||||
|
// any that fatal errors would empty the pool and still at least try 1 fresh
|
||||||
|
// connection.
|
||||||
|
for range p.maxConns + 1 {
|
||||||
res, err := p.p.Acquire(ctx)
|
res, err := p.p.Acquire(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -553,7 +599,8 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
|
|||||||
|
|
||||||
cr := res.Value()
|
cr := res.Value()
|
||||||
|
|
||||||
if res.IdleDuration() > time.Second {
|
shouldPingParams := ShouldPingParams{Conn: cr.conn, IdleDuration: res.IdleDuration()}
|
||||||
|
if p.shouldPing(ctx, shouldPingParams) {
|
||||||
err := cr.conn.Ping(ctx)
|
err := cr.conn.Ping(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
res.Destroy()
|
res.Destroy()
|
||||||
@ -561,12 +608,25 @@ func (p *Pool) Acquire(ctx context.Context) (c *Conn, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
|
if p.prepareConn != nil {
|
||||||
return cr.getConn(p, res), nil
|
ok, err := p.prepareConn(ctx, cr.conn)
|
||||||
|
if !ok {
|
||||||
|
res.Destroy()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if ok {
|
||||||
|
res.Release()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
res.Destroy()
|
return cr.getConn(p, res), nil
|
||||||
}
|
}
|
||||||
|
return nil, errors.New("pgxpool: detected infinite loop acquiring connection; likely bug in PrepareConn or BeforeAcquire hook")
|
||||||
}
|
}
|
||||||
|
|
||||||
// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the
|
// AcquireFunc acquires a *Conn and calls f with that *Conn. ctx will only affect the Acquire. It has no effect on the
|
||||||
@ -589,11 +649,14 @@ func (p *Pool) AcquireAllIdle(ctx context.Context) []*Conn {
|
|||||||
conns := make([]*Conn, 0, len(resources))
|
conns := make([]*Conn, 0, len(resources))
|
||||||
for _, res := range resources {
|
for _, res := range resources {
|
||||||
cr := res.Value()
|
cr := res.Value()
|
||||||
if p.beforeAcquire == nil || p.beforeAcquire(ctx, cr.conn) {
|
if p.prepareConn != nil {
|
||||||
conns = append(conns, cr.getConn(p, res))
|
ok, err := p.prepareConn(ctx, cr.conn)
|
||||||
} else {
|
if !ok || err != nil {
|
||||||
res.Destroy()
|
res.Destroy()
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
conns = append(conns, cr.getConn(p, res))
|
||||||
}
|
}
|
||||||
|
|
||||||
return conns
|
return conns
|
||||||
|
@ -204,6 +204,47 @@ func TestPoolAcquireChecksIdleConns(t *testing.T) {
|
|||||||
require.NotContains(t, pids, cPID)
|
require.NotContains(t, pids, cPID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPoolAcquireChecksIdleConnsWithShouldPing(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
controllerConn, err := pgx.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer controllerConn.Close(ctx)
|
||||||
|
|
||||||
|
config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Replace the default ShouldPing func
|
||||||
|
var shouldPingLastCalledWith *pgxpool.ShouldPingParams
|
||||||
|
config.ShouldPing = func(ctx context.Context, params pgxpool.ShouldPingParams) bool {
|
||||||
|
shouldPingLastCalledWith = ¶ms
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := pgxpool.NewWithConfig(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
c, err := pool.Acquire(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
c.Release()
|
||||||
|
|
||||||
|
time.Sleep(time.Millisecond * 200)
|
||||||
|
|
||||||
|
c, err = pool.Acquire(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conn := c.Conn()
|
||||||
|
|
||||||
|
require.NotNil(t, shouldPingLastCalledWith)
|
||||||
|
assert.Equal(t, conn, shouldPingLastCalledWith.Conn)
|
||||||
|
assert.InDelta(t, time.Millisecond*200, shouldPingLastCalledWith.IdleDuration, float64(time.Millisecond*100))
|
||||||
|
|
||||||
|
c.Release()
|
||||||
|
}
|
||||||
|
|
||||||
func TestPoolAcquireFunc(t *testing.T) {
|
func TestPoolAcquireFunc(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -330,6 +371,64 @@ func TestPoolBeforeAcquire(t *testing.T) {
|
|||||||
assert.EqualValues(t, 12, acquireAttempts)
|
assert.EqualValues(t, 12, acquireAttempts)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPoolPrepareConn(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
config, err := pgxpool.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
acquireAttempts := 0
|
||||||
|
|
||||||
|
config.PrepareConn = func(context.Context, *pgx.Conn) (bool, error) {
|
||||||
|
acquireAttempts++
|
||||||
|
var err error
|
||||||
|
if acquireAttempts%3 == 0 {
|
||||||
|
err = errors.New("PrepareConn error")
|
||||||
|
}
|
||||||
|
return acquireAttempts%2 == 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
db, err := pgxpool.NewWithConfig(ctx, config)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(db.Close)
|
||||||
|
|
||||||
|
var errorCount int
|
||||||
|
conns := make([]*pgxpool.Conn, 0, 4)
|
||||||
|
for {
|
||||||
|
conn, err := db.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
errorCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
conns = append(conns, conn)
|
||||||
|
if len(conns) == 4 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const wantErrorCount = 3
|
||||||
|
assert.Equal(t, wantErrorCount, errorCount, "Acquire() should have failed %d times", wantErrorCount)
|
||||||
|
|
||||||
|
for _, c := range conns {
|
||||||
|
c.Release()
|
||||||
|
}
|
||||||
|
waitForReleaseToComplete()
|
||||||
|
|
||||||
|
assert.EqualValues(t, len(conns)*2+wantErrorCount-1, acquireAttempts)
|
||||||
|
|
||||||
|
conns = db.AcquireAllIdle(ctx)
|
||||||
|
assert.Len(t, conns, 1)
|
||||||
|
|
||||||
|
for _, c := range conns {
|
||||||
|
c.Release()
|
||||||
|
}
|
||||||
|
waitForReleaseToComplete()
|
||||||
|
|
||||||
|
assert.EqualValues(t, 14, acquireAttempts)
|
||||||
|
}
|
||||||
|
|
||||||
func TestPoolAfterRelease(t *testing.T) {
|
func TestPoolAfterRelease(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
@ -677,7 +776,6 @@ func TestPoolQuery(t *testing.T) {
|
|||||||
stats = pool.Stat()
|
stats = pool.Stat()
|
||||||
assert.EqualValues(t, 0, stats.AcquiredConns())
|
assert.EqualValues(t, 0, stats.AcquiredConns())
|
||||||
assert.EqualValues(t, 1, stats.TotalConns())
|
assert.EqualValues(t, 1, stats.TotalConns())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPoolQueryRow(t *testing.T) {
|
func TestPoolQueryRow(t *testing.T) {
|
||||||
@ -1082,9 +1180,9 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
|
|||||||
acquireAttempts := int64(0)
|
acquireAttempts := int64(0)
|
||||||
connectAttempts := int64(0)
|
connectAttempts := int64(0)
|
||||||
|
|
||||||
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
|
config.PrepareConn = func(ctx context.Context, conn *pgx.Conn) (bool, error) {
|
||||||
atomic.AddInt64(&acquireAttempts, 1)
|
atomic.AddInt64(&acquireAttempts, 1)
|
||||||
return true
|
return true, nil
|
||||||
}
|
}
|
||||||
config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error {
|
config.BeforeConnect = func(ctx context.Context, cfg *pgx.ConnConfig) error {
|
||||||
atomic.AddInt64(&connectAttempts, 1)
|
atomic.AddInt64(&connectAttempts, 1)
|
||||||
@ -1105,7 +1203,6 @@ func TestConnectEagerlyReachesMinPoolSize(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
t.Fatal("did not reach min pool size")
|
t.Fatal("did not reach min pool size")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPoolSendBatchBatchCloseTwice(t *testing.T) {
|
func TestPoolSendBatchBatchCloseTwice(t *testing.T) {
|
||||||
|
@ -568,7 +568,6 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) {
|
|||||||
ensureConnValid(t, conn)
|
ensureConnValid(t, conn)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryEncodeError(t *testing.T) {
|
func TestQueryEncodeError(t *testing.T) {
|
||||||
@ -2208,7 +2207,6 @@ insert into products (name, price) values
|
|||||||
}
|
}
|
||||||
|
|
||||||
rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12)
|
rows, err := conn.Query(ctx, "select name, price from products where price < $1 order by price desc", 12)
|
||||||
|
|
||||||
// It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare
|
// It is unnecessary to check err. If an error occurred it will be returned by rows.Err() later. But in rare
|
||||||
// cases it may be useful to detect the error as early as possible.
|
// cases it may be useful to detect the error as early as possible.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
23
rows.go
23
rows.go
@ -41,22 +41,19 @@ type Rows interface {
|
|||||||
// when there was an error executing the query.
|
// when there was an error executing the query.
|
||||||
FieldDescriptions() []pgconn.FieldDescription
|
FieldDescriptions() []pgconn.FieldDescription
|
||||||
|
|
||||||
// Next prepares the next row for reading. It returns true if there is another
|
// Next prepares the next row for reading. It returns true if there is another row and false if no more rows are
|
||||||
// row and false if no more rows are available or a fatal error has occurred.
|
// available or a fatal error has occurred. It automatically closes rows upon returning false (whether due to all rows
|
||||||
// It automatically closes rows when all rows are read.
|
// having been read or due to an error).
|
||||||
//
|
//
|
||||||
// Callers should check rows.Err() after rows.Next() returns false to detect
|
// Callers should check rows.Err() after rows.Next() returns false to detect whether result-set reading ended
|
||||||
// whether result-set reading ended prematurely due to an error. See
|
// prematurely due to an error. See Conn.Query for details.
|
||||||
// Conn.Query for details.
|
|
||||||
//
|
//
|
||||||
// For simpler error handling, consider using the higher-level pgx v5
|
// For simpler error handling, consider using the higher-level pgx v5 CollectRows() and ForEachRow() helpers instead.
|
||||||
// CollectRows() and ForEachRow() helpers instead.
|
|
||||||
Next() bool
|
Next() bool
|
||||||
|
|
||||||
// Scan reads the values from the current row into dest values positionally.
|
// Scan reads the values from the current row into dest values positionally. dest can include pointers to core types,
|
||||||
// dest can include pointers to core types, values implementing the Scanner
|
// values implementing the Scanner interface, and nil. nil will skip the value entirely. It is an error to call Scan
|
||||||
// interface, and nil. nil will skip the value entirely. It is an error to
|
// without first calling Next() and checking that it returned true. Rows is automatically closed upon error.
|
||||||
// call Scan without first calling Next() and checking that it returned true.
|
|
||||||
Scan(dest ...any) error
|
Scan(dest ...any) error
|
||||||
|
|
||||||
// Values returns the decoded row values. As with Scan(), it is an error to
|
// Values returns the decoded row values. As with Scan(), it is an error to
|
||||||
@ -563,7 +560,7 @@ func (rs *mapRowScanner) ScanRow(rows Rows) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number a public fields as row
|
// RowToStructByPos returns a T scanned from row. T must be a struct. T must have the same number of public fields as row
|
||||||
// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
|
// has fields. The row and T fields will be matched by position. If the "db" struct tag is "-" then the field will be
|
||||||
// ignored.
|
// ignored.
|
||||||
func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
func RowToStructByPos[T any](row CollectableRow) (T, error) {
|
||||||
|
@ -471,7 +471,8 @@ func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.Nam
|
|||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
args := namedValueToInterface(argsV)
|
args := make([]any, len(argsV))
|
||||||
|
convertNamedArguments(args, argsV)
|
||||||
|
|
||||||
commandTag, err := c.conn.Exec(ctx, query, args...)
|
commandTag, err := c.conn.Exec(ctx, query, args...)
|
||||||
// if we got a network error before we had a chance to send the query, retry
|
// if we got a network error before we had a chance to send the query, retry
|
||||||
@ -488,8 +489,9 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
|||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []any{databaseSQLResultFormats}
|
args := make([]any, 1+len(argsV))
|
||||||
args = append(args, namedValueToInterface(argsV)...)
|
args[0] = databaseSQLResultFormats
|
||||||
|
convertNamedArguments(args[1:], argsV)
|
||||||
|
|
||||||
rows, err := c.conn.Query(ctx, query, args...)
|
rows, err := c.conn.Query(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -848,28 +850,14 @@ func (r *Rows) Next(dest []driver.Value) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func valueToInterface(argsV []driver.Value) []any {
|
func convertNamedArguments(args []any, argsV []driver.NamedValue) {
|
||||||
args := make([]any, 0, len(argsV))
|
for i, v := range argsV {
|
||||||
for _, v := range argsV {
|
|
||||||
if v != nil {
|
|
||||||
args = append(args, v.(any))
|
|
||||||
} else {
|
|
||||||
args = append(args, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func namedValueToInterface(argsV []driver.NamedValue) []any {
|
|
||||||
args := make([]any, 0, len(argsV))
|
|
||||||
for _, v := range argsV {
|
|
||||||
if v.Value != nil {
|
if v.Value != nil {
|
||||||
args = append(args, v.Value.(any))
|
args[i] = v.Value.(any)
|
||||||
} else {
|
} else {
|
||||||
args = append(args, nil)
|
args[i] = nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return args
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type wrapTx struct {
|
type wrapTx struct {
|
||||||
|
@ -161,7 +161,6 @@ func writeEncryptedPrivateKey(path string, privateKey *rsa.PrivateKey, password
|
|||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeCertificate(path string, certBytes []byte) error {
|
func writeCertificate(path string, certBytes []byte) error {
|
||||||
|
@ -104,7 +104,7 @@ func logQueryArgs(args []any) []any {
|
|||||||
}
|
}
|
||||||
case string:
|
case string:
|
||||||
if len(v) > 64 {
|
if len(v) > 64 {
|
||||||
var l = 0
|
l := 0
|
||||||
for w := 0; l < 64; l += w {
|
for w := 0; l < 64; l += w {
|
||||||
_, w = utf8.DecodeRuneInString(v[l:])
|
_, w = utf8.DecodeRuneInString(v[l:])
|
||||||
}
|
}
|
||||||
|
@ -362,7 +362,6 @@ func TestLogBatchStatementsOnExec(t *testing.T) {
|
|||||||
assert.Equal(t, "BatchQuery", logger.logs[1].msg)
|
assert.Equal(t, "BatchQuery", logger.logs[1].msg)
|
||||||
assert.Equal(t, "drop table foo", logger.logs[1].data["sql"])
|
assert.Equal(t, "drop table foo", logger.logs[1].data["sql"])
|
||||||
assert.Equal(t, "BatchClose", logger.logs[2].msg)
|
assert.Equal(t, "BatchClose", logger.logs[2].msg)
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,7 +117,6 @@ func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) {
|
|||||||
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
|
testJSONInt16ArrayFailureDueToOverflow(t, conn, typename)
|
||||||
testJSONStruct(t, conn, typename)
|
testJSONStruct(t, conn, typename)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func testJSONString(t testing.TB, conn *pgx.Conn, typename string) {
|
func testJSONString(t testing.TB, conn *pgx.Conn, typename string) {
|
||||||
@ -597,7 +596,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
assert func(testing.TB, any, any)
|
assert func(testing.TB, any, any)
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"select $1::bool[]", []bool{true, false, true}, &[]bool{},
|
"select $1::bool[]",
|
||||||
|
[]bool{true, false, true},
|
||||||
|
&[]bool{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]bool))) {
|
||||||
t.Errorf("failed to encode bool[]")
|
t.Errorf("failed to encode bool[]")
|
||||||
@ -605,7 +606,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{},
|
"select $1::smallint[]",
|
||||||
|
[]int16{2, 4, 484, 32767},
|
||||||
|
&[]int16{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]int16))) {
|
||||||
t.Errorf("failed to encode smallint[]")
|
t.Errorf("failed to encode smallint[]")
|
||||||
@ -613,7 +616,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{},
|
"select $1::smallint[]",
|
||||||
|
[]uint16{2, 4, 484, 32767},
|
||||||
|
&[]uint16{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]uint16))) {
|
||||||
t.Errorf("failed to encode smallint[]")
|
t.Errorf("failed to encode smallint[]")
|
||||||
@ -621,7 +626,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::int[]", []int32{2, 4, 484}, &[]int32{},
|
"select $1::int[]",
|
||||||
|
[]int32{2, 4, 484},
|
||||||
|
&[]int32{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]int32))) {
|
||||||
t.Errorf("failed to encode int[]")
|
t.Errorf("failed to encode int[]")
|
||||||
@ -629,7 +636,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{},
|
"select $1::int[]",
|
||||||
|
[]uint32{2, 4, 484, 2147483647},
|
||||||
|
&[]uint32{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]uint32))) {
|
||||||
t.Errorf("failed to encode int[]")
|
t.Errorf("failed to encode int[]")
|
||||||
@ -637,7 +646,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{},
|
"select $1::bigint[]",
|
||||||
|
[]int64{2, 4, 484, 9223372036854775807},
|
||||||
|
&[]int64{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]int64))) {
|
||||||
t.Errorf("failed to encode bigint[]")
|
t.Errorf("failed to encode bigint[]")
|
||||||
@ -645,7 +656,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{},
|
"select $1::bigint[]",
|
||||||
|
[]uint64{2, 4, 484, 9223372036854775807},
|
||||||
|
&[]uint64{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]uint64))) {
|
||||||
t.Errorf("failed to encode bigint[]")
|
t.Errorf("failed to encode bigint[]")
|
||||||
@ -653,7 +666,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{},
|
"select $1::text[]",
|
||||||
|
[]string{"it's", "over", "9000!"},
|
||||||
|
&[]string{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
if !reflect.DeepEqual(query, *(scan.(*[]string))) {
|
if !reflect.DeepEqual(query, *(scan.(*[]string))) {
|
||||||
t.Errorf("failed to encode text[]")
|
t.Errorf("failed to encode text[]")
|
||||||
@ -661,7 +676,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{},
|
"select $1::timestamptz[]",
|
||||||
|
[]time.Time{time.Unix(323232, 0), time.Unix(3239949334, 0o0)},
|
||||||
|
&[]time.Time{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
queryTimeSlice := query.([]time.Time)
|
queryTimeSlice := query.([]time.Time)
|
||||||
scanTimeSlice := *(scan.(*[]time.Time))
|
scanTimeSlice := *(scan.(*[]time.Time))
|
||||||
@ -672,7 +689,9 @@ func TestArrayDecoding(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{},
|
"select $1::bytea[]",
|
||||||
|
[][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}},
|
||||||
|
&[][]byte{},
|
||||||
func(t testing.TB, query, scan any) {
|
func(t testing.TB, query, scan any) {
|
||||||
queryBytesSliceSlice := query.([][]byte)
|
queryBytesSliceSlice := query.([][]byte)
|
||||||
scanBytesSliceSlice := *(scan.(*[][]byte))
|
scanBytesSliceSlice := *(scan.(*[][]byte))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user