// Copyright 2016-2023, Pulumi Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package diy

import (
	"runtime"
	"sync"

	"github.com/hashicorp/go-multierror"
)

// workerPool is a worker pool that runs tasks in parallel.
// It's similar to errgroup.Group with the following differences:
//
//   - re-uses goroutines between tasks instead of spawning them anew
//   - allows multiple Wait calls separately from Close
//   - does not stop on the first error, recording them all instead
//
// This makes it a better fit for phased operations
// where we want to run all tasks to completion and report all errors
// and then run more tasks if those succeed
// without the overhead of spawning new goroutines.
//
// Typical usage is:
//
//	pool := newWorkerPool(numWorkers, len(items))
//	defer pool.Close()
//
//	for _, x := range items {
//		x := x // capture loop variable
//		pool.Enqueue(func() error {
//			...
//			return err
//		})
//	}
//	if err := pool.Wait(); err != nil {
//		...
//	}
//
//	// Add more tasks:
//	pool.Enqueue(func() error { ... })
//	...
//	if err := pool.Wait(); err != nil {
//		...
//	}
type workerPool struct {
	// Number of worker goroutines started by newWorkerPool.
	numWorkers int

	// Tracks individual workers.
	//
	// When this hits zero, all workers have exited and Close can return.
	workers sync.WaitGroup

	// Tracks ongoing tasks.
	//
	// Incremented by Enqueue and decremented by a worker after the task
	// finishes. When this hits zero, no work is in flight, so Wait can return.
	ongoing sync.WaitGroup

	// Channel through which Enqueue hands tasks to the workers.
	// Closed by Close, which makes the workers exit.
	tasks chan func() error

	// Errors returned by tasks since the last Wait call.
	errs  []error
	errMu sync.Mutex // guards errs
}

// newWorkerPool starts a pool of worker goroutines and returns it.
//
// numWorkers is the desired number of workers; a value of zero or less
// means runtime.GOMAXPROCS(0).
//
// numTasksHint, when positive, is the expected maximum number of tasks
// enqueued between Wait calls; the pool never starts more workers than that.
// Pass 0 when the number of tasks is unknown.
func newWorkerPool(numWorkers, numTasksHint int) *workerPool {
	n := numWorkers
	if n <= 0 {
		n = runtime.GOMAXPROCS(0)
	}

	// Extra workers beyond the expected task count would only sit idle.
	if numTasksHint > 0 && n > numTasksHint {
		n = numTasksHint
	}

	pool := &workerPool{numWorkers: n}

	// The task channel is deliberately unbuffered: Enqueue blocks until a
	// worker picks the task up. That keeps the intended usage simple
	// (a batch of Enqueue calls followed by Wait) and makes Enqueue from
	// inside a task a deadlock, which discourages it.
	pool.tasks = make(chan func() error)

	pool.workers.Add(n)
	for started := 0; started < n; started++ {
		go pool.worker()
	}
	return pool
}

// worker receives tasks from p.tasks and runs them one at a time until the
// channel is closed. Errors returned by tasks are appended to p.errs.
func (p *workerPool) worker() {
	defer p.workers.Done()

	for {
		task, ok := <-p.tasks
		if !ok {
			// Close closed the channel; nothing more to do.
			return
		}

		err := task()
		if err != nil {
			p.errMu.Lock()
			p.errs = append(p.errs, err)
			p.errMu.Unlock()
		}

		// Mark the task done only after its error (if any) is recorded,
		// so that Wait observes it once ongoing reaches zero.
		p.ongoing.Done()
	}
}

// Enqueue submits task to the pool and blocks until a worker receives it.
//
// Any error returned by task is reported by the next call to Wait.
//
// Enqueue must not be called from inside a task (it would deadlock), nor
// concurrently with Wait or Close.
func (p *workerPool) Enqueue(task func() error) {
	queue := p.tasks

	// Register the task as in flight before handing it over: the worker
	// may finish it and call ongoing.Done before this send returns.
	p.ongoing.Add(1)
	queue <- task
}

// Wait blocks until all enqueued tasks have finished
// and returns all errors encountered while running them.
// A single error is returned as-is; multiple errors are
// combined into a new *multierror.Error.
//
// Consecutive calls to Wait return errors encountered since last Wait.
// Typically, you'll want to call Wait after all Enqueue calls
// and then again after any additional Enqueue calls.
//
//	pool.Enqueue(t1)
//	pool.Enqueue(t2)
//	pool.Enqueue(t3)
//	if err := pool.Wait(); err != nil {
//		// One or more of t1, t2, t3 failed.
//		return err
//	}
//
//	pool.Enqueue(t4)
//	pool.Enqueue(t5)
//	if err := pool.Wait(); err != nil {
//		// One or more of t4, t5 failed.
//		return err
//	}
func (p *workerPool) Wait() error {
	p.ongoing.Wait()

	// Take the errors collected so far and reset the list, so the next
	// Wait only reports failures from tasks enqueued after this point.
	p.errMu.Lock()
	errs := p.errs
	p.errs = nil
	p.errMu.Unlock()

	switch len(errs) {
	case 0:
		return nil
	case 1:
		return errs[0]
	default:
		// Build a fresh aggregate instead of appending into errs[0]:
		// if a task returned a *multierror.Error, Append(errs[0], ...)
		// would mutate that task's error value in place.
		return multierror.Append(nil, errs...)
	}
}

// Close shuts the pool down and waits for every worker goroutine to exit.
//
// Enqueue and Wait must not be called after Close.
func (p *workerPool) Close() {
	queue := p.tasks

	// Closing the channel ends each worker's receive loop.
	close(queue)

	// Block until all workers have returned.
	p.workers.Wait()
}