mirror of https://github.com/jackc/pgx.git
Add CopyFrom to pool
parent
7b1272d254
commit
d93de3fdc7
|
@ -57,7 +57,7 @@ type copyFrom struct {
|
|||
readerErrChan chan error
|
||||
}
|
||||
|
||||
func (ct *copyFrom) run(ctx context.Context) (int, error) {
|
||||
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||
quotedTableName := ct.tableName.Sanitize()
|
||||
cbuf := &bytes.Buffer{}
|
||||
for i, cn := range ct.columnNames {
|
||||
|
@ -113,7 +113,7 @@ func (ct *copyFrom) run(ctx context.Context) (int, error) {
|
|||
|
||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||
|
||||
return int(commandTag.RowsAffected()), err
|
||||
return commandTag.RowsAffected(), err
|
||||
}
|
||||
|
||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
|
||||
|
@ -149,7 +149,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
|
|||
// CopyFrom requires all values use the binary format. Almost all types
|
||||
// implemented by pgx use the binary format by default. Types implementing
|
||||
// Encoder can only be used if they encode to the binary format.
|
||||
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
||||
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
||||
ct := ©From{
|
||||
conn: c,
|
||||
tableName: tableName,
|
||||
|
|
|
@ -39,7 +39,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
|
@ -97,7 +97,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
|
@ -152,7 +152,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||
}
|
||||
if copyCount != len(inputRows) {
|
||||
if int(copyCount) != len(inputRows) {
|
||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||
}
|
||||
|
||||
|
|
|
@ -94,3 +94,42 @@ func testSendBatch(t *testing.T, db sendBatcher) {
|
|||
err = br.Close()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type copyFromer interface {
|
||||
CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
|
||||
}
|
||||
|
||||
func testCopyFrom(t *testing.T, db interface {
|
||||
execer
|
||||
queryer
|
||||
copyFromer
|
||||
}) {
|
||||
_, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
|
||||
require.NoError(t, err)
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
rows, err := db.Query(context.Background(), "select * from foo")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
assert.NoError(t, rows.Err())
|
||||
assert.Equal(t, inputRows, outputRows)
|
||||
}
|
||||
|
|
|
@ -65,6 +65,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
|||
return c.Conn().SendBatch(ctx, b)
|
||||
}
|
||||
|
||||
func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
||||
func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
|
||||
return c.Conn().Begin(ctx, txOptions)
|
||||
}
|
||||
|
|
|
@ -56,3 +56,15 @@ func TestConnSendBatch(t *testing.T) {
|
|||
|
||||
testSendBatch(t, c)
|
||||
}
|
||||
|
||||
func TestConnCopyFrom(t *testing.T) {
|
||||
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
c, err := pool.Acquire(context.Background())
|
||||
require.NoError(t, err)
|
||||
defer c.Release()
|
||||
|
||||
testCopyFrom(t, c)
|
||||
}
|
||||
|
|
10
pool/pool.go
10
pool/pool.go
|
@ -150,3 +150,13 @@ func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error)
|
|||
|
||||
return &Tx{t: t, c: c}, err
|
||||
}
|
||||
|
||||
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer c.Release()
|
||||
|
||||
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -103,6 +104,51 @@ func TestPoolSendBatch(t *testing.T) {
|
|||
assert.EqualValues(t, 1, stats.TotalConns())
|
||||
}
|
||||
|
||||
func TestPoolCopyFrom(t *testing.T) {
|
||||
// Not able to use testCopyFrom because it relies on temporary tables and the pool may run subsequent calls under
|
||||
// different connections.
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
_, err = pool.Exec(ctx, `drop table if exists poolcopyfromtest`)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = pool.Exec(ctx, `create table poolcopyfromtest(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
|
||||
require.NoError(t, err)
|
||||
defer pool.Exec(ctx, `drop table poolcopyfromtest`)
|
||||
|
||||
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||
{nil, nil, nil, nil, nil, nil, nil},
|
||||
}
|
||||
|
||||
copyCount, err := pool.CopyFrom(ctx, pgx.Identifier{"poolcopyfromtest"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||
assert.NoError(t, err)
|
||||
assert.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
rows, err := pool.Query(ctx, "select * from poolcopyfromtest")
|
||||
assert.NoError(t, err)
|
||||
|
||||
var outputRows [][]interface{}
|
||||
for rows.Next() {
|
||||
row, err := rows.Values()
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||
}
|
||||
outputRows = append(outputRows, row)
|
||||
}
|
||||
|
||||
assert.NoError(t, rows.Err())
|
||||
assert.Equal(t, inputRows, outputRows)
|
||||
}
|
||||
|
||||
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error)
|
||||
func (p *ConnPool) Deallocate(name string) (err error)
|
||||
func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
|
||||
|
||||
|
|
|
@ -49,3 +49,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx
|
|||
func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||
return tx.c.SendBatch(ctx, b)
|
||||
}
|
||||
|
||||
func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||
return tx.c.CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||
}
|
||||
|
|
|
@ -56,3 +56,15 @@ func TestTxSendBatch(t *testing.T) {
|
|||
|
||||
testSendBatch(t, tx)
|
||||
}
|
||||
|
||||
func TestTxCopyFrom(t *testing.T) {
|
||||
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||
require.NoError(t, err)
|
||||
defer pool.Close()
|
||||
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer tx.Rollback(context.Background())
|
||||
|
||||
testCopyFrom(t, tx)
|
||||
}
|
||||
|
|
2
tx.go
2
tx.go
|
@ -177,7 +177,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row
|
|||
}
|
||||
|
||||
// CopyFrom delegates to the underlying *Conn
|
||||
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
||||
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
return 0, ErrTxClosed
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue