diff --git a/hw05_parallel_execution/misc.go b/hw05_parallel_execution/misc.go new file mode 100644 index 0000000..6197dc3 --- /dev/null +++ b/hw05_parallel_execution/misc.go @@ -0,0 +1,14 @@ +package hw05_parallel_execution //nolint:golint,stylecheck + +import ( + "errors" + "sync" +) + +var ErrErrorsLimitExceeded = errors.New("errors limit exceeded") + +type Errors struct { + count int + mx sync.RWMutex +} +type Task func() error diff --git a/hw05_parallel_execution/mixed.go b/hw05_parallel_execution/mixed.go new file mode 100644 index 0000000..6bb35c8 --- /dev/null +++ b/hw05_parallel_execution/mixed.go @@ -0,0 +1,48 @@ +package hw05_parallel_execution //nolint:golint,stylecheck + +import ( + "sync" +) + +func RunMixed(tasks []Task, n int, m int) error { + ch := make(chan Task) + wg := sync.WaitGroup{} + errs := Errors{} + errs.count = m + var ignore bool + if errs.count < 0 { + ignore = true + } + + wg.Add(n) + for i := 0; i < n; i++ { + go consMixed(ch, &wg, &errs, ignore) + } + + for _, task := range tasks { + if !ignore && errs.count <= 0 { + break + } + ch <- task + } + close(ch) + wg.Wait() + + if errs.count <= 0 && !ignore { + return ErrErrorsLimitExceeded + } + return nil +} + +func consMixed(ch <-chan Task, wg *sync.WaitGroup, errs *Errors, ignore bool) { + defer wg.Done() + for task := range ch { + if !ignore && errs.count <= 0 { + return + } + err := task() + if !ignore && err != nil { + errs.count-- + } + } +} diff --git a/hw05_parallel_execution/mixed_test.go b/hw05_parallel_execution/mixed_test.go new file mode 100644 index 0000000..0a4d15e --- /dev/null +++ b/hw05_parallel_execution/mixed_test.go @@ -0,0 +1,103 @@ +package hw05_parallel_execution //nolint:golint,stylecheck + +import ( + "fmt" + "math/rand" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/goleak" +) + +func TestRunMixed(t *testing.T) { + defer goleak.VerifyNone(t) + + t.Run("50 tasks in 10 goroutines with 100% errors and maxErrors=23 should run not more N+M (=33) tasks", func(t *testing.T) { + tasksCount := 50 + tasks := make([]Task, 0, tasksCount) + + var runTasksCount int32 + + for i := 0; i < tasksCount; i++ { + err := fmt.Errorf("error from task %d", i) + tasks = append(tasks, func() error { + time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) + atomic.AddInt32(&runTasksCount, 1) + return err + }) + } + + workersCount := 10 + maxErrorsCount := 23 + result := RunMixed(tasks, workersCount, maxErrorsCount) + + require.Equal(t, ErrErrorsLimitExceeded, result) + require.LessOrEqual(t, runTasksCount, int32(workersCount+maxErrorsCount), "extra tasks were started") + }) + + t.Run("50 tasks in 5 goroutines with 50% errors and m<0 should ignore errores", func(t *testing.T) { + tasksCount := 50 + tasks := make([]Task, 0, tasksCount) + + var runTasksCount int32 + var sumTime time.Duration + + for i := 0; i < tasksCount; i++ { + err := fmt.Errorf("error from task %d", i) + taskSleep := time.Millisecond * time.Duration(rand.Intn(100)) + sumTime += taskSleep + + tasks = append(tasks, func() error { + time.Sleep(taskSleep) + atomic.AddInt32(&runTasksCount, 1) + if i%2 == 0 { + return err + } + return nil + }) + } + + workersCount := 5 + maxErrorsCount := -1 + + start := time.Now() + result := RunMixed(tasks, workersCount, maxErrorsCount) + elapsedTime := time.Since(start) + require.Nil(t, result) + + require.Equal(t, runTasksCount, int32(tasksCount), "not all tasks were completed") + require.LessOrEqual(t, int64(elapsedTime), int64(sumTime/2), "tasks were run sequentially?") + }) + + t.Run("143 tasks in 7 goroutines without errors", func(t *testing.T) { + tasksCount := 143 + tasks := make([]Task, 0, tasksCount) + + var runTasksCount int32 + var sumTime time.Duration + + for i := 0; i < tasksCount; i++ { + taskSleep := time.Millisecond * time.Duration(rand.Intn(100)) + sumTime += taskSleep + + tasks = append(tasks, func() error { + time.Sleep(taskSleep) + atomic.AddInt32(&runTasksCount, 1) + return nil + }) + } + + workersCount := 7 + maxErrorsCount := 1 + + start := time.Now() + result := RunMixed(tasks, workersCount, maxErrorsCount) + elapsedTime := time.Since(start) + require.Nil(t, result) + + require.Equal(t, runTasksCount, int32(tasksCount), "not all tasks were completed") + require.LessOrEqual(t, int64(elapsedTime), int64(sumTime/2), "tasks were run sequentially?") + }) +} diff --git a/hw05_parallel_execution/run.go b/hw05_parallel_execution/run.go index b603505..52cf1fb 100644 --- a/hw05_parallel_execution/run.go +++ b/hw05_parallel_execution/run.go @@ -1,65 +1,27 @@ package hw05_parallel_execution //nolint:golint,stylecheck -import ( - "errors" - "sync" -) - -var ErrErrorsLimitExceeded = errors.New("errors limit exceeded") - -type Errors struct { - count int - mx sync.RWMutex -} -type Task func() error - func Run(tasks []Task, n int, m int) error { - ch := make(chan Task) - wg := sync.WaitGroup{} - errs := Errors{} - errs.count = m - var ignore bool - if errs.count < 0 { - ignore = true + if m == -1 { + m = len(tasks) } - - wg.Add(n) - for i := 0; i < n; i++ { - go consumer(ch, &wg, &errs, ignore) - } - + pool := make(chan int, n) + errs := make(chan int) for _, task := range tasks { - errs.mx.RLock() - if !ignore && errs.count <= 0 { - errs.mx.RUnlock() - break + pool <- 1 + go func(task Task, errs chan int, pool chan int) { + if _, ok := <-errs; !ok { + return + } + if task() != nil { + errs <- 1 + } + <-pool + }(task, errs, pool) + if len(errs) >= m { + close(errs) + return ErrErrorsLimitExceeded } - errs.mx.RUnlock() - ch <- task - } - close(ch) - wg.Wait() - - if errs.count <= 0 && !ignore { - return ErrErrorsLimitExceeded } + close(errs) return nil } - -func consumer(ch <-chan Task, wg *sync.WaitGroup, errs *Errors, ignore bool) { - defer wg.Done() - for task := range ch { - errs.mx.RLock() - if !ignore && errs.count <= 0 { - errs.mx.RUnlock() - return - } - errs.mx.RUnlock() - err := task() - if !ignore && err != nil { - errs.mx.Lock() - errs.count-- - errs.mx.Unlock() - } - } -} diff --git a/hw05_parallel_execution/run_test.go b/hw05_parallel_execution/run_test.go index 37e3e3c..e6550f3 100644 --- a/hw05_parallel_execution/run_test.go +++ b/hw05_parallel_execution/run_test.go @@ -31,8 +31,8 @@ func TestRun(t *testing.T) { workersCount := 10 maxErrorsCount := 23 - result := Run(tasks, workersCount, maxErrorsCount) + result := Run(tasks, workersCount, maxErrorsCount) require.Equal(t, ErrErrorsLimitExceeded, result) require.LessOrEqual(t, runTasksCount, int32(workersCount+maxErrorsCount), "extra tasks were started") })