pulumi/pkg/backend/diy/pool.go

// Copyright 2016-2023, Pulumi Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package diy

import (
	"runtime"
	"sync"

	"github.com/hashicorp/go-multierror"
)

// workerPool is a worker pool that runs tasks in parallel.
// It's similar to errgroup.Group with the following differences:
//
//   - re-uses goroutines between tasks instead of spawning them anew
//   - allows multiple Wait calls separately from Close
//   - does not stop on the first error, recording them all instead
//
// This makes it a better fit for phased operations
// where we want to run all tasks to completion, report all errors,
// and then run more tasks if those succeed,
// without the overhead of spawning new goroutines.
//
// Typical usage is:
//
//	pool := newWorkerPool(numWorkers, numTasksHint)
//	defer pool.Close()
//
//	for _, x := range items {
//		x := x // capture loop variable
//		pool.Enqueue(func() error {
//			...
//			return err
//		})
//	}
//	if err := pool.Wait(); err != nil {
//		...
//	}
//
//	// Add more tasks:
//	pool.Enqueue(func() error { ... })
//	...
//	if err := pool.Wait(); err != nil {
//		...
//	}
type workerPool struct {
	// Target number of workers.
	numWorkers int

	// Tracks individual workers.
	//
	// When this hits zero, all workers have exited and Close can return.
	workers sync.WaitGroup

	// Tracks ongoing tasks.
	//
	// When this hits zero, we no longer have any work in flight,
	// so Wait can return.
	ongoing sync.WaitGroup

	// Enqueues tasks for workers.
	tasks chan func() error

	// Errors encountered while running tasks.
	errs  []error
	errMu sync.Mutex // guards errs
}

// newWorkerPool creates a new worker pool with the given number of workers.
// If numWorkers is zero or negative, it defaults to GOMAXPROCS.
//
// numTasksHint is a hint for the maximum number of tasks that will be enqueued
// before a Wait call. It's used to avoid overallocating workers.
// It may be 0 if the number of tasks is unknown.
func newWorkerPool(numWorkers, numTasksHint int) *workerPool {
	if numWorkers <= 0 {
		numWorkers = runtime.GOMAXPROCS(0)
	}

	// If we have a task hint, use it to avoid overallocating workers.
	if numTasksHint > 0 && numTasksHint < numWorkers {
		numWorkers = numTasksHint
	}

	p := &workerPool{
		numWorkers: numWorkers,

		// We use an unbuffered channel.
		// This blocks Enqueue until a worker is ready
		// to accept a task.
		// This keeps usage patterns simple:
		// a few Enqueue calls followed by a Wait call.
		// It discourages attempts to Enqueue from inside a worker
		// because that would deadlock.
		tasks: make(chan func() error),
	}

	p.workers.Add(numWorkers)
	for i := 0; i < numWorkers; i++ {
		go p.worker()
	}

	return p
}
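
// Illustrative sketch, not part of the original file: how numWorkers and
// numTasksHint interact. Passing 0 for numWorkers sizes the pool to
// GOMAXPROCS, and a positive numTasksHint smaller than that caps the worker
// count, so a pool that will only ever see three tasks starts at most three
// goroutines.
//
//	pool := newWorkerPool(0, 3) // at most 3 workers, even on a 16-core machine
//	defer pool.Close()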

// worker pulls tasks off the tasks channel and runs them until the channel
// is closed, recording any errors for a later Wait.
func (p *workerPool) worker() {
	defer p.workers.Done()

	for task := range p.tasks {
		if err := task(); err != nil {
			p.errMu.Lock()
			p.errs = append(p.errs, err)
			p.errMu.Unlock()
		}
		p.ongoing.Done()
	}
}

// Enqueue adds a task to the pool.
// It blocks until the task is accepted by a worker.
// DO NOT call Enqueue from inside a worker
// or concurrently with Wait or Close.
//
// If the task returns an error, it can be retrieved from Wait.
func (p *workerPool) Enqueue(task func() error) {
	p.ongoing.Add(1)
	p.tasks <- task
}
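
// Illustrative sketch, not part of the original file, of the deadlock the
// warning above describes: with a single worker, the inner Enqueue blocks
// sending on the unbuffered tasks channel while the only worker is still
// busy running the outer task, so neither side can make progress.
//
//	pool := newWorkerPool(1, 0)
//	pool.Enqueue(func() error {
//		pool.Enqueue(func() error { return nil }) // blocks forever: no free worker
//		return nil
//	})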

// Wait blocks until all enqueued tasks have finished
// and returns all errors encountered while running them.
// The errors may be combined into a multierror.
//
// Consecutive calls to Wait return only the errors encountered since the last Wait.
// Typically, you'll want to call Wait after all Enqueue calls
// and then again after any additional Enqueue calls.
//
//	pool.Enqueue(t1)
//	pool.Enqueue(t2)
//	pool.Enqueue(t3)
//	if err := pool.Wait(); err != nil {
//		// One or more of t1, t2, t3 failed.
//		return err
//	}
//
//	pool.Enqueue(t4)
//	pool.Enqueue(t5)
//	if err := pool.Wait(); err != nil {
//		// One or more of t4, t5 failed.
//		return err
//	}
func (p *workerPool) Wait() error {
	p.ongoing.Wait()

	p.errMu.Lock()
	var errors []error
	errors, p.errs = p.errs, nil
	p.errMu.Unlock()

	switch len(errors) {
	case 0:
		return nil
	case 1:
		return errors[0]
	default:
		return multierror.Append(errors[0], errors[1:]...)
	}
}

// Close stops the pool and cleans up all resources.
//
// Do not call Enqueue or Wait after Close.
func (p *workerPool) Close() {
	close(p.tasks)
	p.workers.Wait()
}
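
// The function below is an illustrative sketch, not part of the original
// file. It exercises the phased pattern described in the workerPool doc
// comment: run one batch of tasks to completion, surface every error, and
// only then re-use the same workers for a second batch. The item slice and
// the process/finalize callbacks are hypothetical stand-ins for real work.
func exampleWorkerPoolUsage(items []string, process, finalize func(string) error) error {
	pool := newWorkerPool(0 /* default to GOMAXPROCS */, len(items))
	defer pool.Close()

	// Phase 1: process every item, collecting all errors before deciding
	// whether to continue.
	for _, item := range items {
		item := item // capture loop variable
		pool.Enqueue(func() error { return process(item) })
	}
	if err := pool.Wait(); err != nil {
		return err
	}

	// Phase 2: the goroutines started by newWorkerPool are re-used here;
	// no new ones are spawned for the second batch.
	for _, item := range items {
		item := item // capture loop variable
		pool.Enqueue(func() error { return finalize(item) })
	}
	return pool.Wait()
}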