mirror of https://github.com/jackc/pgx.git
Add CopyFrom to pool
parent
7b1272d254
commit
d93de3fdc7
|
@ -57,7 +57,7 @@ type copyFrom struct {
|
||||||
readerErrChan chan error
|
readerErrChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ct *copyFrom) run(ctx context.Context) (int, error) {
|
func (ct *copyFrom) run(ctx context.Context) (int64, error) {
|
||||||
quotedTableName := ct.tableName.Sanitize()
|
quotedTableName := ct.tableName.Sanitize()
|
||||||
cbuf := &bytes.Buffer{}
|
cbuf := &bytes.Buffer{}
|
||||||
for i, cn := range ct.columnNames {
|
for i, cn := range ct.columnNames {
|
||||||
|
@ -113,7 +113,7 @@ func (ct *copyFrom) run(ctx context.Context) (int, error) {
|
||||||
|
|
||||||
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
|
||||||
|
|
||||||
return int(commandTag.RowsAffected()), err
|
return commandTag.RowsAffected(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
|
func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byte, error) {
|
||||||
|
@ -149,7 +149,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
|
||||||
// CopyFrom requires all values use the binary format. Almost all types
|
// CopyFrom requires all values use the binary format. Almost all types
|
||||||
// implemented by pgx use the binary format by default. Types implementing
|
// implemented by pgx use the binary format by default. Types implementing
|
||||||
// Encoder can only be used if they encode to the binary format.
|
// Encoder can only be used if they encode to the binary format.
|
||||||
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
||||||
ct := ©From{
|
ct := ©From{
|
||||||
conn: c,
|
conn: c,
|
||||||
tableName: tableName,
|
tableName: tableName,
|
||||||
|
|
|
@ -39,7 +39,7 @@ func TestConnCopyFromSmall(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||||
}
|
}
|
||||||
if copyCount != len(inputRows) {
|
if int(copyCount) != len(inputRows) {
|
||||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,7 +97,7 @@ func TestConnCopyFromLarge(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||||
}
|
}
|
||||||
if copyCount != len(inputRows) {
|
if int(copyCount) != len(inputRows) {
|
||||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -152,7 +152,7 @@ func TestConnCopyFromJSON(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
t.Errorf("Unexpected error for CopyFrom: %v", err)
|
||||||
}
|
}
|
||||||
if copyCount != len(inputRows) {
|
if int(copyCount) != len(inputRows) {
|
||||||
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
t.Errorf("Expected CopyFrom to return %d copied rows, but got %d", len(inputRows), copyCount)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -94,3 +94,42 @@ func testSendBatch(t *testing.T, db sendBatcher) {
|
||||||
err = br.Close()
|
err = br.Close()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type copyFromer interface {
|
||||||
|
CopyFrom(context.Context, pgx.Identifier, []string, pgx.CopyFromSource) (int64, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testCopyFrom(t *testing.T, db interface {
|
||||||
|
execer
|
||||||
|
queryer
|
||||||
|
copyFromer
|
||||||
|
}) {
|
||||||
|
_, err := db.Exec(context.Background(), `create temporary table foo(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||||
|
{nil, nil, nil, nil, nil, nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := db.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, len(inputRows), copyCount)
|
||||||
|
|
||||||
|
rows, err := db.Query(context.Background(), "select * from foo")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, rows.Err())
|
||||||
|
assert.Equal(t, inputRows, outputRows)
|
||||||
|
}
|
||||||
|
|
|
@ -65,6 +65,10 @@ func (c *Conn) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||||
return c.Conn().SendBatch(ctx, b)
|
return c.Conn().SendBatch(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||||
|
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
|
func (c *Conn) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*pgx.Tx, error) {
|
||||||
return c.Conn().Begin(ctx, txOptions)
|
return c.Conn().Begin(ctx, txOptions)
|
||||||
}
|
}
|
||||||
|
|
|
@ -56,3 +56,15 @@ func TestConnSendBatch(t *testing.T) {
|
||||||
|
|
||||||
testSendBatch(t, c)
|
testSendBatch(t, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnCopyFrom(t *testing.T) {
|
||||||
|
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
c, err := pool.Acquire(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer c.Release()
|
||||||
|
|
||||||
|
testCopyFrom(t, c)
|
||||||
|
}
|
||||||
|
|
10
pool/pool.go
10
pool/pool.go
|
@ -150,3 +150,13 @@ func (p *Pool) Begin(ctx context.Context, txOptions *pgx.TxOptions) (*Tx, error)
|
||||||
|
|
||||||
return &Tx{t: t, c: c}, err
|
return &Tx{t: t, c: c}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *Pool) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||||
|
c, err := p.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer c.Release()
|
||||||
|
|
||||||
|
return c.Conn().CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||||
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v4"
|
||||||
"github.com/jackc/pgx/v4/pool"
|
"github.com/jackc/pgx/v4/pool"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
@ -103,6 +104,51 @@ func TestPoolSendBatch(t *testing.T) {
|
||||||
assert.EqualValues(t, 1, stats.TotalConns())
|
assert.EqualValues(t, 1, stats.TotalConns())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPoolCopyFrom(t *testing.T) {
|
||||||
|
// Not able to use testCopyFrom because it relies on temporary tables and the pool may run subsequent calls under
|
||||||
|
// different connections.
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pool, err := pool.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
_, err = pool.Exec(ctx, `drop table if exists poolcopyfromtest`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = pool.Exec(ctx, `create table poolcopyfromtest(a int2, b int4, c int8, d varchar, e text, f date, g timestamptz)`)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pool.Exec(ctx, `drop table poolcopyfromtest`)
|
||||||
|
|
||||||
|
tzedTime := time.Date(2010, 2, 3, 4, 5, 6, 0, time.Local)
|
||||||
|
|
||||||
|
inputRows := [][]interface{}{
|
||||||
|
{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime},
|
||||||
|
{nil, nil, nil, nil, nil, nil, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount, err := pool.CopyFrom(ctx, pgx.Identifier{"poolcopyfromtest"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, len(inputRows), copyCount)
|
||||||
|
|
||||||
|
rows, err := pool.Query(ctx, "select * from poolcopyfromtest")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var outputRows [][]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
row, err := rows.Values()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error for rows.Values(): %v", err)
|
||||||
|
}
|
||||||
|
outputRows = append(outputRows, row)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, rows.Err())
|
||||||
|
assert.Equal(t, inputRows, outputRows)
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
func TestConnReleaseRollsBackFailedTransaction(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
func (p *ConnPool) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error)
|
|
||||||
func (p *ConnPool) Deallocate(name string) (err error)
|
func (p *ConnPool) Deallocate(name string) (err error)
|
||||||
func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
|
func (p *ConnPool) Prepare(ctx context.Context, name, sql string) (*PreparedStatement, error)
|
||||||
|
|
||||||
|
|
|
@ -49,3 +49,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx
|
||||||
func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
func (tx *Tx) SendBatch(ctx context.Context, b *pgx.Batch) pgx.BatchResults {
|
||||||
return tx.c.SendBatch(ctx, b)
|
return tx.c.SendBatch(ctx, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {
|
||||||
|
return tx.c.CopyFrom(ctx, tableName, columnNames, rowSrc)
|
||||||
|
}
|
||||||
|
|
|
@ -56,3 +56,15 @@ func TestTxSendBatch(t *testing.T) {
|
||||||
|
|
||||||
testSendBatch(t, tx)
|
testSendBatch(t, tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTxCopyFrom(t *testing.T) {
|
||||||
|
pool, err := pool.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer pool.Close()
|
||||||
|
|
||||||
|
tx, err := pool.Begin(context.Background(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tx.Rollback(context.Background())
|
||||||
|
|
||||||
|
testCopyFrom(t, tx)
|
||||||
|
}
|
||||||
|
|
2
tx.go
2
tx.go
|
@ -177,7 +177,7 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row
|
||||||
}
|
}
|
||||||
|
|
||||||
// CopyFrom delegates to the underlying *Conn
|
// CopyFrom delegates to the underlying *Conn
|
||||||
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
|
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int64, error) {
|
||||||
if tx.status != TxStatusInProgress {
|
if tx.status != TxStatusInProgress {
|
||||||
return 0, ErrTxClosed
|
return 0, ErrTxClosed
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue