diff --git a/copy_from.go b/copy_from.go index b924412d..afa80a1d 100644 --- a/copy_from.go +++ b/copy_from.go @@ -75,8 +75,11 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { } r, w := io.Pipe() + doneChan := make(chan struct{}) go func() { + defer close(doneChan) + // Purposely NOT using defer w.Close(). See https://github.com/golang/go/issues/24283. buf := ct.conn.wbuf @@ -114,6 +117,9 @@ func (ct *copyFrom) run(ctx context.Context) (int64, error) { commandTag, err := ct.conn.pgConn.CopyFrom(ctx, r, fmt.Sprintf("copy %s ( %s ) from stdin binary;", quotedTableName, quotedColumnNames)) + r.Close() + <-doneChan + return commandTag.RowsAffected(), err } diff --git a/copy_from_test.go b/copy_from_test.go index 5b1612ec..e9155d32 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -333,6 +333,52 @@ func TestConnCopyFromFailServerSideMidwayAbortsWithoutWaiting(t *testing.T) { ensureConnValid(t, conn) } +type slowFailRaceSource struct { + count int +} + +func (fs *slowFailRaceSource) Next() bool { + time.Sleep(time.Millisecond) + fs.count++ + return fs.count < 1000 +} + +func (fs *slowFailRaceSource) Values() ([]interface{}, error) { + if fs.count == 500 { + return []interface{}{nil, nil}, nil + } + return []interface{}{1, make([]byte, 1000)}, nil +} + +func (fs *slowFailRaceSource) Err() error { + return nil +} + +func TestConnCopyFromSlowFailRace(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int not null, + b bytea not null + )`) + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a", "b"}, &slowFailRaceSource{}) + if err == nil { + t.Errorf("Expected CopyFrom return error, but it did not") + } + if _, ok := err.(*pgconn.PgError); !ok { + t.Errorf("Expected CopyFrom return pgx.PgError, but instead it returned: %v", err) + } + if copyCount != 0 { + t.Errorf("Expected CopyFrom to return 0 copied rows, but got %d", copyCount) + } + + ensureConnValid(t, conn) +} + func TestConnCopyFromCopyFromSourceErrorMidway(t *testing.T) { t.Parallel()