// Copyright 2016-2023, Pulumi Corporation. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cmdutil import ( "bytes" "context" "errors" "fmt" "io" "os/exec" "path/filepath" "runtime" "sync" "testing" "time" ps "github.com/mitchellh/go-ps" "github.com/pulumi/pulumi/sdk/v3/go/common/testing/iotest" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTerminate_gracefulShutdown(t *testing.T) { t.Parallel() // This test runs commands in a child process, signals them, // and expects them to shutdown gracefully. // // The contract for the child process is as follows: // // - It MUST print something to stdout when it is ready to receive signals. // - It MUST exit with a zero code if it receives a SIGINT. // - It MUST exit with a non-zero code if the signal wasn't received within 3 seconds. // - It MAY print diagnostic messages to stderr. tests := []struct { desc string prog testProgram }{ {desc: "go", prog: goTestProgram.From("graceful.go")}, {desc: "node", prog: nodeTestProgram.From("graceful.js")}, {desc: "python", prog: pythonTestProgram.From("graceful.py")}, {desc: "with child", prog: goTestProgram.From("graceful_with_child.go")}, } for _, tt := range tests { tt := tt t.Run(tt.desc, func(t *testing.T) { t.Parallel() cmd := tt.prog.Build(t) var stdout lockedBuffer cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: ")) cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ") require.NoError(t, cmd.Start(), "error starting child process") done := make(chan struct{}) go func() { defer close(done) // Wait until the child process is ready to receive signals. for stdout.Len() == 0 { time.Sleep(10 * time.Millisecond) } ok, err := TerminateProcessGroup(cmd.Process, 1*time.Second) assert.True(t, ok, "child process did not exit gracefully") assert.NoError(t, err, "error terminating child process") }() err := cmd.Wait() if isWaitAlreadyExited(err) { err = nil } assert.NoError(t, err, "child did not exit cleanly") <-done }) } } func TestTerminate_gracefulShutdown_exitError(t *testing.T) { t.Parallel() // This test runs commands in a child process, signals them, // and expects them to shutdown gracefully // but with a non-zero exit code. cmd := goTestProgram.From("graceful.go").Args("-exit-code", "1").Build(t) var stdout lockedBuffer cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: ")) cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ") require.NoError(t, cmd.Start(), "error starting child process") // Wait until the child process is ready to receive signals. for stdout.Len() == 0 { time.Sleep(10 * time.Millisecond) } ok, err := TerminateProcessGroup(cmd.Process, 1*time.Second) assert.True(t, ok, "child process did not exit gracefully") require.Error(t, err, "child process must exit with non-zero code") var exitErr *exec.ExitError if assert.ErrorAs(t, err, &exitErr, "expected ExitError from child process") { assert.Equal(t, 1, exitErr.ExitCode(), "unexpected exit code from child process") } } func TestTerminate_forceKill(t *testing.T) { t.Parallel() // This test runs commands in a child process, signals them, // and expects them to not exit in a timely manner. // // The contract for the child process is the same as gracefulShutdown, // except: // // - It MUST freeze for at least 1 second after it receives a SIGINT. // - It MAY exit with a non-zero code if it receives a SIGINT. tests := []struct { desc string prog testProgram }{ {desc: "go", prog: goTestProgram.From("frozen.go")}, {desc: "node", prog: nodeTestProgram.From("frozen.js")}, {desc: "python", prog: pythonTestProgram.From("frozen.py")}, } for _, tt := range tests { tt := tt t.Run(tt.desc, func(t *testing.T) { t.Parallel() cmd := tt.prog.Build(t) var stdout lockedBuffer cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: ")) cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ") require.NoError(t, cmd.Start(), "error starting child process") // Wait until the child process is ready to receive signals. for stdout.Len() == 0 { time.Sleep(10 * time.Millisecond) } pid := cmd.Process.Pid done := make(chan struct{}) go func() { defer close(done) ok, err := TerminateProcessGroup(cmd.Process, 50*time.Millisecond) assert.False(t, ok, "child process should not exit gracefully") assert.NoError(t, err, "error terminating child process") }() select { case <-done: // continue case <-time.After(200 * time.Millisecond): // If the process is not killed, // cmd.Wait() will block until it exits. t.Fatal("Took too long to kill child process") } assert.NoError(t, waitPidDead(pid, 100*time.Millisecond), "error waiting for process to die") }) } } func TestTerminate_forceKill_processGroup(t *testing.T) { t.Parallel() // This is a variant of TestTerminate_forceKill // that verifies that a child process of the test process // is also killed. cmd := goTestProgram.From("frozen_with_child.go").Build(t) var stdout lockedBuffer cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: ")) cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ") require.NoError(t, cmd.Start(), "error starting child process") // Wait until the child process is ready to receive signals. for stdout.Len() == 0 { time.Sleep(10 * time.Millisecond) } pid := cmd.Process.Pid childPid := -1 procs, err := ps.Processes() require.NoError(t, err, "error listing processes") for _, proc := range procs { if proc.PPid() == pid { childPid = proc.Pid() break } } require.NotEqual(t, -1, childPid, "child process not found") done := make(chan struct{}) go func() { defer close(done) ok, err := TerminateProcessGroup(cmd.Process, time.Millisecond) assert.False(t, ok, "child process should not exit gracefully") assert.NoError(t, err, "error terminating child process") }() select { case <-done: // continue case <-time.After(100 * time.Millisecond): // If the child process is not killed, // cmd.Wait() will block until it exits. t.Fatal("Took too long to kill child process") } for _, pid := range []int{pid, childPid} { assert.NoError(t, waitPidDead(pid, 100*time.Millisecond), "error waiting for process to die") } } func TestTerminate_unhandledInterrupt(t *testing.T) { t.Parallel() // This test runs programs that do not have an interrupt handler. // Contract for child process: // // - It MUST print to stdout when it's ready. // - It MUST exit with a non-zero code if it does not get terminated within 3 seconds. tests := []struct { desc string prog testProgram }{ {desc: "go", prog: goTestProgram.From("unhandled.go")}, {desc: "node", prog: nodeTestProgram.From("unhandled.js")}, {desc: "python", prog: pythonTestProgram.From("unhandled.py")}, } for _, tt := range tests { tt := tt t.Run(tt.desc, func(t *testing.T) { t.Parallel() cmd := tt.prog.Build(t) var stdout lockedBuffer cmd.Stdout = io.MultiWriter(&stdout, iotest.LogWriterPrefixed(t, "stdout: ")) cmd.Stderr = iotest.LogWriterPrefixed(t, "stderr: ") require.NoError(t, cmd.Start(), "error starting child process") // Wait until the child process is ready to receive signals. for stdout.Len() == 0 { time.Sleep(10 * time.Millisecond) } pid := cmd.Process.Pid done := make(chan struct{}) go func() { defer close(done) ok, err := TerminateProcessGroup(cmd.Process, 200*time.Millisecond) assert.True(t, ok, "child process did not exit gracefully") assert.Error(t, err, "child process should have exited with an error") }() select { case <-done: // continue case <-time.After(200 * time.Millisecond): // Took too long to kill the child process. t.Fatal("Took too long to kill child process") } assert.NoError(t, waitPidDead(pid, 100*time.Millisecond), "error waiting for process to die") }) } } type testProgramKind int const ( goTestProgram testProgramKind = iota nodeTestProgram pythonTestProgram ) func (k testProgramKind) String() string { switch k { case goTestProgram: return "go" case nodeTestProgram: return "node" case pythonTestProgram: return "python" default: return fmt.Sprintf("testProgramKind(%d)", int(k)) } } // From builds a testProgram of this kind // with the given source file. // // Usage: // // goTestProgram.From("main.go") func (k testProgramKind) From(path string) testProgram { return testProgram{ kind: k, src: path, } } // testProgram is a test program inside the testdata directory. type testProgram struct { // kind is the kind of test program. kind testProgramKind // src is the path to the source file // relative to the testdata directory. src string // args specifies additional arguments to pass to the program. args []string } func (p testProgram) Args(args ...string) testProgram { p.args = args return p } // Build builds an exec.Cmd for the test program. // It skips the test if the program runner is not found. func (p testProgram) Build(t *testing.T) (cmd *exec.Cmd) { t.Helper() defer func() { // Make sure that the returned command // is part of the process group. if cmd != nil { RegisterProcessGroup(cmd) } }() src := filepath.Join("testdata", p.src) switch p.kind { case goTestProgram: goBin := lookPathOrSkip(t, "go") bin := filepath.Join(t.TempDir(), "main") if runtime.GOOS == "windows" { bin += ".exe" } buildCmd := exec.Command(goBin, "build", "-o", bin, src) buildOutput := iotest.LogWriterPrefixed(t, "build: ") buildCmd.Stdout = buildOutput buildCmd.Stderr = buildOutput require.NoError(t, buildCmd.Run(), "error building test program") return exec.Command(bin, p.args...) case nodeTestProgram: nodeBin := lookPathOrSkip(t, "node") return exec.Command(nodeBin, append([]string{src}, p.args...)...) case pythonTestProgram: pythonBin := lookPathOrSkip(t, "python") return exec.Command(pythonBin, append([]string{src}, p.args...)...) default: t.Fatalf("unknown test program kind: %v", p.kind) return nil } } func lookPathOrSkip(t *testing.T, name string) string { path, err := exec.LookPath(name) if err != nil { t.Skipf("Skipping test: %q not found: %v", name, err) } return path } // lockedBuffer is a thread-safe bytes.Buffer // that can be used to capture stdout/stderr of a command. type lockedBuffer struct { mu sync.RWMutex b bytes.Buffer } func (b *lockedBuffer) Write(p []byte) (int, error) { b.mu.Lock() defer b.mu.Unlock() return b.b.Write(p) } func (b *lockedBuffer) Len() int { b.mu.RLock() defer b.mu.RUnlock() return b.b.Len() } // Waits until the process with the given pid doesn't exist anymore // or the given timeout has elapsed. // // Returns an error if the timeout has elapsed. func waitPidDead(pid int, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() var ( proc ps.Process err error ) for { select { case <-ctx.Done(): var errs []error if proc != nil { errs = append(errs, fmt.Errorf("process %d still exists: %v", pid, proc)) } if err != nil { errs = append(errs, fmt.Errorf("find process: %w", err)) } return fmt.Errorf("waitPidDead %v: %w", pid, errors.Join(errs...)) default: proc, err = ps.FindProcess(pid) if err == nil && proc == nil { return nil } time.Sleep(10 * time.Millisecond) } } }