From 3b9f79e2f3d6a414a0590e4da3b5a027fd737f0e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 14 Feb 2020 17:30:44 -0600 Subject: [PATCH] Fix race condition in CopyFrom In case of an error it was possible for the goroutine that builds the copy stream to still be running after CopyFrom returned. Since that goroutine uses the connections ConnInfo data types to encode the copy data it was possible for those types to be concurrently used in an unsafe fashion. CopyFrom will no longer return until that goroutine has completed. --- copy_from.go | 6 ++++++ copy_from_test.go | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/copy_from.go b/copy_from.go index b924412d..afa80a1d 100644 --- a/copy_from.go +++ b/copy_from.go @@ -75,8 +75,11 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { } r, w := io.Pipe() + doneChan := make(chan struct{}) go func() { + defer close(doneChan) + // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. buf := ct.conn.wbuf @@ -114,6 +117,9 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + r.Close() + <-doneChan + return commandTag.RowsAffected(), err } diff --git a/copy_from_test.go b/copy_from_test.go index 5b1612ec..e9155d32 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -333,6 +333,52 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } +type slowFailRaceSource struct { + count int +} + +func (fs *slowFailRaceSource) Next() bool { + time.Sleep(time.Millisecond) + fs.count++ + return fs.count < 1000 +} + +func (fs *slowFailRaceSource) Values() ([]interface{}, error) { + if fs.count == 500 { + return []interface{}{nil, nil}, nil + } + return []interface{}{1, make([]byte, 1000)}, nil +} + +func (fs *slowFailRaceSource) Err() error { + return nil +} + +func TestConnCopyFromSlowFailRace(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int not null, + b bytea not null + )`) + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(*pgconn.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + ensureConnValid(t, conn) +} + func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Parallel()