mirror of
https://github.com/jackc/pgx.git
synced 2025-05-22 07:20:35 +00:00
Use DefaultQueryExecMode in CopyFrom
CopyFrom had to create a prepared statement to get the OIDs of the data types that were going to be copied into the table. Every COPY operation required an extra round trips to retrieve the type information. There was no way to customize this behavior. By leveraging the QueryExecMode feature, like in `Conn.Query`, users can specify if they want to cache the prepared statements, execute them on every request (like the old behavior), or bypass the prepared statement relying on the pgtype.Map to get the type information. The `QueryExecMode` behave exactly like in `Conn.Query` in the way the data type OIDs are fetched, meaning that: - `QueryExecModeCacheStatement`: caches the statement. - `QueryExecModeCacheDescribe`: caches the statement and assumes they do not change. - `QueryExecModeDescribeExec`: gets the statement description on every execution. This is like to the old behavior of `CopyFrom`. - `QueryExecModeExec` and `QueryExecModeSimpleProtocol`: maintain the same behavior as before, which is the same as `QueryExecModeDescribeExec`. It will keep getting the statement description on every execution The `QueryExecMode` can only be set via `ConnConfig.DefaultQueryExecMode`, unlike `Conn.Query` there's no support for specifying the `QueryExecMode` via optional arguments in the function signature.
This commit is contained in:
parent
456a242f5c
commit
c4ac6d810f
2
.gitignore
vendored
2
.gitignore
vendored
@ -23,3 +23,5 @@ _testmain.go
|
||||
|
||||
.envrc
|
||||
/.testdb
|
||||
|
||||
.DS_Store
|
||||
|
77
conn.go
77
conn.go
@ -721,44 +721,11 @@ optionLoop:
|
||||
sd, explicitPreparedStatement := c.preparedStatements[sql]
|
||||
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
|
||||
if sd == nil {
|
||||
switch mode {
|
||||
case QueryExecModeCacheStatement:
|
||||
if c.statementCache == nil {
|
||||
err = errDisabledStatementCache
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd = c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
sd, err = c.getStatementDescription(ctx, mode, sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
err = errDisabledDescriptionCache
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
sd = c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeDescribeExec:
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(sd.ParamOIDs) != len(args) {
|
||||
@ -827,6 +794,48 @@ optionLoop:
|
||||
return rows, rows.err
|
||||
}
|
||||
|
||||
// getStatementDescription returns the statement description of the sql query
|
||||
// according to the given mode.
|
||||
//
|
||||
// If the mode is one that doesn't require to know the param and result OIDs
|
||||
// then nil is returned without error.
|
||||
func (c *Conn) getStatementDescription(
|
||||
ctx context.Context,
|
||||
mode QueryExecMode,
|
||||
sql string,
|
||||
) (sd *pgconn.StatementDescription, err error) {
|
||||
|
||||
switch mode {
|
||||
case QueryExecModeCacheStatement:
|
||||
if c.statementCache == nil {
|
||||
return nil, errDisabledStatementCache
|
||||
}
|
||||
sd = c.statementCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.statementCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeCacheDescribe:
|
||||
if c.descriptionCache == nil {
|
||||
return nil, errDisabledDescriptionCache
|
||||
}
|
||||
sd = c.descriptionCache.Get(sql)
|
||||
if sd == nil {
|
||||
sd, err = c.Prepare(ctx, "", sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.descriptionCache.Put(sd)
|
||||
}
|
||||
case QueryExecModeDescribeExec:
|
||||
return c.Prepare(ctx, "", sql)
|
||||
}
|
||||
return sd, err
|
||||
}
|
||||
|
||||
// QueryRow is a convenience wrapper over Query. Any error that occurs while
|
||||
// querying is deferred until calling Scan on the returned Row. That Row will
|
||||
// error with ErrNoRows if no rows are returned.
|
||||
|
26
copy_from.go
26
copy_from.go
@ -85,6 +85,7 @@ type copyFrom struct {
|
||||
columnNames []string
|
||||
rowSrc CopyFromSource
|
||||
readerErrChan chan error
|
||||
mode QueryExecMode
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
@ -105,9 +106,29 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
}
|
||||
quotedColumnNames := cbuf.String()
|
||||
|
||||
sd, err := ct.conn.Prepare(ctx, "", fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName))
|
||||
var sd *pgconn.StatementDescription
|
||||
switch ct.mode {
|
||||
case QueryExecModeExec, QueryExecModeSimpleProtocol:
|
||||
// These modes don't support the binary format. Before the inclusion of the
|
||||
// QueryExecModes, Conn.Prepare was called on every COPY operation to get
|
||||
// the OIDs. These prepared statements were not cached.
|
||||
//
|
||||
// Since that's the same behavior provided by QueryExecModeDescribeExec,
|
||||
// we'll default to that mode.
|
||||
ct.mode = QueryExecModeDescribeExec
|
||||
fallthrough
|
||||
case QueryExecModeCacheStatement, QueryExecModeCacheDescribe, QueryExecModeDescribeExec:
|
||||
var err error
|
||||
sd, err = ct.conn.getStatementDescription(
|
||||
ctx,
|
||||
ct.mode,
|
||||
fmt.Sprintf("select %s from %s", quotedColumnNames, quotedTableName),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return 0, fmt.Errorf("statement description failed: %w", err)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown QueryExecMode: %v", ct.mode)
|
||||
}
|
||||
|
||||
r, w := io.Pipe()
|
||||
@ -208,6 +229,7 @@ func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames [
|
||||
columnNames: columnNames,
|
||||
rowSrc: rowSrc,
|
||||
readerErrChan: make(chan error),
|
||||
mode: c.config.DefaultQueryExecMode,
|
||||
}
|
||||
|
||||
return ct.run(ctx)
|
||||
|
@ -14,6 +14,129 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnCopyWithAllQueryExecModes(t *testing.T) {
|
||||
for _, mode := range pgxtest.AllQueryExecModes {
|
||||
t.Run(mode.String(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
cfg.DefaultQueryExecMode = mode
|
||||
conn := mustConnect(t, cfg)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d text,
|
||||
e timestamptz
|
||||
)`)
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]any{
|
||||
{int16(0), int32(1), int64(2), "abc", tzedTime},
|
||||
{nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnCopyWithKnownOIDQueryExecModes(t *testing.T) {
|
||||
|
||||
for _, mode := range pgxtest.KnownOIDQueryExecModes {
|
||||
t.Run(mode.String(), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
cfg.DefaultQueryExecMode = mode
|
||||
conn := mustConnect(t, cfg)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int2,
|
||||
b int4,
|
||||
c int8,
|
||||
d varchar,
|
||||
e text,
|
||||
f date,
|
||||
g timestamptz
|
||||
)`)
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]any{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(context.Background(), "select * from foo")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for Query: %v", err)
|
||||
}
|
||||
|
||||
var outputRows [][]any
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
t.Errorf("Unexpected error for rows.Err(): %v", rows.Err())
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(inputRows, outputRows) {
|
||||
t.Errorf("Input rows and output rows do not equal: %v -> %v", inputRows, outputRows)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnCopyFromSmall(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -220,7 +343,7 @@ func TestConnCopyFromEnum(t *testing.T) {
|
||||
conn.TypeMap().RegisterType(typ)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `create table foo(
|
||||
_, err = tx.Exec(ctx, `create temporary table foo(
|
||||
a text,
|
||||
b color,
|
||||
c fruit,
|
||||
|
@ -18,6 +18,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/internal/pgmock"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
@ -1666,6 +1668,59 @@ func TestConnCopyFrom(t *testing.T) {
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromBinary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer closeConn(t, pgConn)
|
||||
|
||||
_, err = pgConn.Exec(context.Background(), `create temporary table foo(
|
||||
a int4,
|
||||
b varchar
|
||||
)`).ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := []byte{}
|
||||
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
buf = pgio.AppendInt32(buf, 0)
|
||||
|
||||
inputRows := [][][]byte{}
|
||||
for i := 0; i < 1000; i++ {
|
||||
// Number of elements in the tuple
|
||||
buf = pgio.AppendInt16(buf, int16(2))
|
||||
a := i
|
||||
|
||||
// Length of element for column `a int4`
|
||||
buf = pgio.AppendInt32(buf, 4)
|
||||
buf, err = pgtype.NewMap().Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
b := "foo " + strconv.Itoa(a) + " bar"
|
||||
lenB := int32(len([]byte(b)))
|
||||
// Length of element for column `b varchar`
|
||||
buf = pgio.AppendInt32(buf, lenB)
|
||||
buf, err = pgtype.NewMap().Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)})
|
||||
}
|
||||
|
||||
srcBuf := &bytes.Buffer{}
|
||||
srcBuf.Write(buf)
|
||||
ct, err := pgConn.CopyFrom(context.Background(), srcBuf, "COPY foo (a, b) FROM STDIN BINARY;")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
||||
|
||||
result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
|
||||
require.NoError(t, result.Err)
|
||||
|
||||
assert.Equal(t, inputRows, result.Rows)
|
||||
|
||||
ensureConnValid(t, pgConn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromCanceled(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -217,7 +217,7 @@ func TestLogCopyFrom(t *testing.T) {
|
||||
return config
|
||||
}
|
||||
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
pgxtest.RunWithQueryExecModes(context.Background(), t, ctr, pgxtest.KnownOIDQueryExecModes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
||||
_, err := conn.Exec(context.Background(), `create temporary table foo(a int4)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user