Fix race condition in CopyFrom

In case of an error it was possible for the goroutine that builds the
copy stream to still be running after CopyFrom returned. Since that
goroutine uses the connections ConnInfo data types to encode the copy
data it was possible for those types to be concurrently used in an
unsafe fashion.

CopyFrom will no longer return until that goroutine has completed.
pull/681/head
Jack Christensen 2020-02-14 17:30:44 -06:00
parent 8c9d1cc15b
commit 3b9f79e2f3
2 changed files with 52 additions and 0 deletions

View File

@ -75,8 +75,11 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
}
r, w := io.Pipe()
doneChan := make(chan struct{})
go func() {
defer close(doneChan)
// Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283.
buf := ct.conn.wbuf
@ -114,6 +117,9 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) {
commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames))
r.Close()
<-doneChan
return commandTag.RowsAffected(), err
}

View File

@ -333,6 +333,52 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) {
ensureConnValid(t, conn)
}
type slowFailRaceSource struct {
count int
}
func (fs *slowFailRaceSource) Next() bool {
time.Sleep(time.Millisecond)
fs.count++
return fs.count < 1000
}
func (fs *slowFailRaceSource) Values() ([]interface{}, error) {
if fs.count == 500 {
return []interface{}{nil, nil}, nil
}
return []interface{}{1, make([]byte, 1000)}, nil
}
func (fs *slowFailRaceSource) Err() error {
return nil
}
func TestConnCopyFromSlowFailRace(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, `create temporary table foo(
a int not null,
b bytea not null
)`)
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{})
if err == nil {
t.Errorf("Expected CopyFrom return error, but it did not")
}
if _, ok := err.(*pgconn.PgError); !ok {
t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err)
}
if copyCount != 0 {
t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount)
}
ensureConnValid(t, conn)
}
func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) {
t.Parallel()