Add CopyFrom to pool

pull/483/head
Jack Christensen 2019-04-25 15:35:53 -05:00
parent 7b1272d254
commit d93de3fdc7
11 changed files with 134 additions and 8 deletions

View File

@ -57,7 +57,7 @@ type copyFrom struct {
readerErrChan chan error
}
func (ct *copyFrom) run(ctx context.Context) (int, error) {
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
@ -113,7 +113,7 @@ func (ct *copyFrom) run(ctx context.Context) (int, error) {
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
return int(commandTag.RowsAffected()), err
return commandTag.RowsAffected(), err
}
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
@ -149,7 +149,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
// CopyFrom requires all values use the binary format. Almost all types
// implemented by pgx use the binary format by default. Types implementing
// Encoder can only be used if they encode to the binary format.
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
ct := &copyFrom{
conn: c,
tableName: tableName,

View File

@ -39,7 +39,7 @@ func TestConnCopyFromSmall(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if copyCount != len(inputRows) {
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
@ -97,7 +97,7 @@ func TestConnCopyFromLarge(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if copyCount != len(inputRows) {
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}
@ -152,7 +152,7 @@ func TestConnCopyFromJSON(t *testing.T) {
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
if copyCount != len(inputRows) {
if int(copyCount) != len(inputRows) {
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
}

View File

@ -94,3 +94,42 @@ func testSendBatch(t *testing.T, db sendBatcher) {
err = br.Close()
assert.NoError(t, err)
}
type copyFromer interface {
CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
}
func testCopyFrom(t *testing.T, db interface {
execer
queryer
copyFromer
}) {
_, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
require.NoError(t, err)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
inputRows := [][]interface{}{
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
assert.NoError(t, err)
assert.EqualValues(t, len(inputRows), copyCount)
rows, err := db.Query(context.Background(), "select * from foo")
assert.NoError(t, err)
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
assert.NoError(t, rows.Err())
assert.Equal(t, inputRows, outputRows)
}

View File

@ -65,6 +65,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return c.Conn().SendBatch(ctx, b)
}
func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
}
func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
return c.Conn().Begin(ctx, txOptions)
}

View File

@ -56,3 +56,15 @@ func TestConnSendBatch(t *testing.T) {
testSendBatch(t, c)
}
func TestConnCopyFrom(t *testing.T) {
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()
c, err := pool.Acquire(context.Background())
require.NoError(t, err)
defer c.Release()
testCopyFrom(t, c)
}

View File

@ -150,3 +150,13 @@ func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error)
return &Tx{t: t, c: c}, err
}
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
c, err := p.Acquire(ctx)
if err != nil {
return 0, err
}
defer c.Release()
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
}

View File

@ -6,6 +6,7 @@ import (
"testing"
"time"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -103,6 +104,51 @@ func TestPoolSendBatch(t *testing.T) {
assert.EqualValues(t, 1, stats.TotalConns())
}
func TestPoolCopyFrom(t *testing.T) {
// Not able to use testCopyFrom because it relies on temporary tables and the pool may run subsequent calls under
// different connections.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()
_, err = pool.Exec(ctx, `drop table if exists poolcopyfromtest`)
require.NoError(t, err)
_, err = pool.Exec(ctx, `create table poolcopyfromtest(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
require.NoError(t, err)
defer pool.Exec(ctx, `drop table poolcopyfromtest`)
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
inputRows := [][]interface{}{
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := pool.CopyFrom(ctx, pgx.Identifier{"poolcopyfromtest"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
assert.NoError(t, err)
assert.EqualValues(t, len(inputRows), copyCount)
rows, err := pool.Query(ctx, "select * from poolcopyfromtest")
assert.NoError(t, err)
var outputRows [][]interface{}
for rows.Next() {
row, err := rows.Values()
if err != nil {
t.Errorf("Unexpected error for rows.Values(): %v", err)
}
outputRows = append(outputRows, row)
}
assert.NoError(t, rows.Err())
assert.Equal(t, inputRows, outputRows)
}
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

View File

@ -1,4 +1,3 @@
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error)
func (p *ConnPool) Deallocate(name string) (err error)
func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)

View File

@ -49,3 +49,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx
func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
return tx.c.SendBatch(ctx, b)
}
func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
return tx.c.CopyFrom(ctx, tableName, columnNames, rowSrc)
}

View File

@ -56,3 +56,15 @@ func TestTxSendBatch(t *testing.T) {
testSendBatch(t, tx)
}
func TestTxCopyFrom(t *testing.T) {
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.NoError(t, err)
defer pool.Close()
tx, err := pool.Begin(context.Background(), nil)
require.NoError(t, err)
defer tx.Rollback(context.Background())
testCopyFrom(t, tx)
}

2
tx.go
View File

@ -177,7 +177,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row
}
// CopyFrom delegates to the underlying *Conn
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
if tx.status != TxStatusInProgress {
return 0, ErrTxClosed
}