Conn.CopyFrom takes context

pull/483/head
Jack Christensen 2019-04-20 11:38:23 -05:00
parent 95756b1d7f
commit dc699cefc7
4 changed files with 15 additions and 14 deletions

View File

@ -492,7 +492,8 @@ func benchmarkWriteNRowsViaCopy(b *testing.B, n int) {
for i := 0; i < b.N; i++ {
src := newBenchmarkWriteTableCopyFromSrc(n)
_, err := conn.CopyFrom(pgx.Identifier{"t"},
_, err := conn.CopyFrom(context.Background(),
pgx.Identifier{"t"},
[]string{"varchar_1",
"varchar_2",
"varchar_null_1",

View File

@ -57,7 +57,7 @@ type copyFrom struct {
readerErrChan chan error
}
func (ct *copyFrom) run() (int, error) {
func (ct *copyFrom) run(ctx context.Context) (int, error) {
quotedTableName := ct.tableName.Sanitize()
cbuf := &bytes.Buffer{}
for i, cn := range ct.columnNames {
@ -111,7 +111,7 @@ func (ct *copyFrom) run() (int, error) {
w.Close()
}()
commandTag, err := ct.conn.pgConn.CopyFrom(context.TODO(), r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
return int(commandTag.RowsAffected()), err
}
@ -149,7 +149,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, ps *PreparedStatement) (bool, []byt
// CopyFrom requires all values use the binary format. Almost all types
// implemented by pgx use the binary format by default. Types implementing
// Encoder can only be used if they encode to the binary format.
func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
func (c *Conn) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
ct := &copyFrom{
conn: c,
tableName: tableName,
@ -158,5 +158,5 @@ func (c *Conn) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyF
readerErrChan: make(chan error),
}
return ct.run()
return ct.run(ctx)
}

View File

@ -35,7 +35,7 @@ func TestConnCopyFromSmall(t *testing.T) {
{nil, nil, nil, nil, nil, nil, nil},
}
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -93,7 +93,7 @@ func TestConnCopyFromLarge(t *testing.T) {
inputRows = append(inputRows, []interface{}{int16(0), int32(1), int64(2), "abc", "efg", time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), tzedTime, []byte{111, 111, 111, 111}})
}
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b", "c", "d", "e", "f", "g", "h"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -148,7 +148,7 @@ func TestConnCopyFromJSON(t *testing.T) {
{nil, nil},
}
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
if err != nil {
t.Errorf("Unexpected error for CopyFrom: %v", err)
}
@ -220,7 +220,7 @@ func TestConnCopyFromFailServerSideMidway(t *testing.T) {
{int32(3), "def"},
}
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, pgx.CopyFromRows(inputRows))
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -291,7 +291,7 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
startTime := time.Now()
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &failSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -343,7 +343,7 @@ func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
a bytea not null
)`)
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFailSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
@ -403,7 +403,7 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
a bytea not null
)`)
copyCount, err := conn.CopyFrom(pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, &clientFinalErrSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}

4
tx.go
View File

@ -189,12 +189,12 @@ func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row
}
// CopyFrom delegates to the underlying *Conn
func (tx *Tx) CopyFrom(tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
func (tx *Tx) CopyFrom(ctx context.Context, tableName Identifier, columnNames []string, rowSrc CopyFromSource) (int, error) {
if tx.status != TxStatusInProgress {
return 0, ErrTxClosed
}
return tx.conn.CopyFrom(tableName, columnNames, rowSrc)
return tx.conn.CopyFrom(ctx, tableName, columnNames, rowSrc)
}
// Status returns the status of the transaction from the set of