diff --git a/copy_from.go b/copy_from.go index a6536d2d..368d4c33 100644 --- a/copy_from.go +++ b/copy_from.go @@ -57,7 +57,7 @@ type copyFrom struct { readerErrChan chan error } -func (ct *copyFrom) run(ctx context.Context) (int, error) { +func (ct *copyFrom) run(ctx context.Context) (int64, error) { quotedTableName := ct.tableName.Sanitize() cbuf := &bytes.Buffer{} for i, cn := range ct.columnNames { @@ -113,7 +113,7 @@ func (ct *copyFrom) run(ctx context.Context) (int, error) { commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) - return int(commandTag.RowsAffected()), err + return commandTag.RowsAffected(), err } func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) { @@ -149,7 +149,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt // CopyFrom requires all values use the binary format. Almost all types // implemented by pgx use the binary format by default. Types implementing // Encoder can only be used if they encode to the binary format. -func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { +func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { ct := ©From{ conn: c, tableName: tableName, diff --git a/copy_from_test.go b/copy_from_test.go index 1eb3159c..cb45debe 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -39,7 +39,7 @@ func TestConnCopyFromSmall(t *testing.T) { if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } - if copyCount != len(inputRows) { + if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } @@ -97,7 +97,7 @@ func TestConnCopyFromLarge(t *testing.T) { if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } - if copyCount != len(inputRows) { + if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } @@ -152,7 +152,7 @@ func TestConnCopyFromJSON(t *testing.T) { if err != nil { t.Errorf("Unexpected error for CopyFrom: %v", err) } - if copyCount != len(inputRows) { + if int(copyCount) != len(inputRows) { t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount) } diff --git a/pool/common_test.go b/pool/common_test.go index d0a8fa4a..586d3679 100644 --- a/pool/common_test.go +++ b/pool/common_test.go @@ -94,3 +94,42 @@ func testSendBatch(t *testing.T, db sendBatcher) { err = br.Close() assert.NoError(t, err) } + +type copyFromer interface { + CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error) +} + +func testCopyFrom(t *testing.T, db interface { + execer + queryer + copyFromer +}) { + _, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) + require.NoError(t, err) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + assert.NoError(t, err) + assert.EqualValues(t, len(inputRows), copyCount) + + rows, err := db.Query(context.Background(), "select * from foo") + assert.NoError(t, err) + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, inputRows, outputRows) +} diff --git a/pool/conn.go b/pool/conn.go index 7ccb00f6..05231790 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -65,6 +65,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return c.Conn().SendBatch(ctx, b) } +func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) +} + func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) { return c.Conn().Begin(ctx, txOptions) } diff --git a/pool/conn_test.go b/pool/conn_test.go index af6cc8cb..4f6d0117 100644 --- a/pool/conn_test.go +++ b/pool/conn_test.go @@ -56,3 +56,15 @@ func TestConnSendBatch(t *testing.T) { testSendBatch(t, c) } + +func TestConnCopyFrom(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + c, err := pool.Acquire(context.Background()) + require.NoError(t, err) + defer c.Release() + + testCopyFrom(t, c) +} diff --git a/pool/pool.go b/pool/pool.go index 993077f1..4fdc5efe 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -150,3 +150,13 @@ func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error) return &Tx{t: t, c: c}, err } + +func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + c, err := p.Acquire(ctx) + if err != nil { + return 0, err + } + defer c.Release() + + return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc) +} diff --git a/pool/pool_test.go b/pool/pool_test.go index 86767103..7e14e18a 100644 --- a/pool/pool_test.go +++ b/pool/pool_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4/pool" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -103,6 +104,51 @@ func TestPoolSendBatch(t *testing.T) { assert.EqualValues(t, 1, stats.TotalConns()) } +func TestPoolCopyFrom(t *testing.T) { + // Not able to use testCopyFrom because it relies on temporary tables and the pool may run subsequent calls under + // different connections. + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + _, err = pool.Exec(ctx, `drop table if exists poolcopyfromtest`) + require.NoError(t, err) + + _, err = pool.Exec(ctx, `create table poolcopyfromtest(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`) + require.NoError(t, err) + defer pool.Exec(ctx, `drop table poolcopyfromtest`) + + tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local) + + inputRows := [][]interface{}{ + {int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime}, + {nil, nil, nil, nil, nil, nil, nil}, + } + + copyCount, err := pool.CopyFrom(ctx, pgx.Identifier{"poolcopyfromtest"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows)) + assert.NoError(t, err) + assert.EqualValues(t, len(inputRows), copyCount) + + rows, err := pool.Query(ctx, "select * from poolcopyfromtest") + assert.NoError(t, err) + + var outputRows [][]interface{} + for rows.Next() { + row, err := rows.Values() + if err != nil { + t.Errorf("Unexpected error for rows.Values(): %v", err) + } + outputRows = append(outputRows, row) + } + + assert.NoError(t, rows.Err()) + assert.Equal(t, inputRows, outputRows) +} + func TestConnReleaseRollsBackFailedTransaction(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/pool/todo.txt b/pool/todo.txt index 3b9d1813..4a4acf37 100644 --- a/pool/todo.txt +++ b/pool/todo.txt @@ -1,4 +1,3 @@ -func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) func (p *ConnPool) Deallocate(name string) (err error) func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error) diff --git a/pool/tx.go b/pool/tx.go index 3f0618e0..379c30e0 100644 --- a/pool/tx.go +++ b/pool/tx.go @@ -49,3 +49,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults { return tx.c.SendBatch(ctx, b) } + +func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) { + return tx.c.CopyFrom(ctx, tableName, columnNames, rowSrc) +} diff --git a/pool/tx_test.go b/pool/tx_test.go index 0ea4b8b6..47b1aa82 100644 --- a/pool/tx_test.go +++ b/pool/tx_test.go @@ -56,3 +56,15 @@ func TestTxSendBatch(t *testing.T) { testSendBatch(t, tx) } + +func TestTxCopyFrom(t *testing.T) { + pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + defer pool.Close() + + tx, err := pool.Begin(context.Background(), nil) + require.NoError(t, err) + defer tx.Rollback(context.Background()) + + testCopyFrom(t, tx) +} diff --git a/tx.go b/tx.go index 45477da0..10cabe6e 100644 --- a/tx.go +++ b/tx.go @@ -177,7 +177,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row } // CopyFrom delegates to the underlying *Conn -func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) { +func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) { if tx.status != TxStatusInProgress { return 0, ErrTxClosed }