Skip to content

Commit 3ac3c9d

Browse files
guilhas07a-h
andauthoredSep 2, 2024··
fix: send SIGTERM signal to --cmd instead of SIGKILL (#687)
Co-authored-by: Adrian Hesketh <adrianhesketh@hushmail.com> Co-authored-by: Adrian Hesketh <a-h@users.noreply.github.com>
1 parent c7c32aa commit 3ac3c9d

File tree

4 files changed

+209
-11
lines changed

4 files changed

+209
-11
lines changed
 

‎cmd/templ/generatecmd/run/run_test.go

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package run_test
2+
3+
import (
4+
"context"
5+
"embed"
6+
"io"
7+
"net/http"
8+
"os"
9+
"path/filepath"
10+
"syscall"
11+
"testing"
12+
"time"
13+
14+
"github.com/a-h/templ/cmd/templ/generatecmd/run"
15+
)
16+
17+
//go:embed testprogram/*
18+
var testprogram embed.FS
19+
20+
func TestGoRun(t *testing.T) {
21+
if testing.Short() {
22+
t.Skip("Skipping test in short mode.")
23+
}
24+
25+
// Copy testprogram to a temporary directory.
26+
dir, err := os.MkdirTemp("", "testprogram")
27+
if err != nil {
28+
t.Fatalf("failed to make test dir: %v", err)
29+
}
30+
files, err := testprogram.ReadDir("testprogram")
31+
if err != nil {
32+
t.Fatalf("failed to read embedded dir: %v", err)
33+
}
34+
for _, file := range files {
35+
srcFileName := "testprogram/" + file.Name()
36+
srcData, err := testprogram.ReadFile(srcFileName)
37+
if err != nil {
38+
t.Fatalf("failed to read src file %q: %v", srcFileName, err)
39+
}
40+
tgtFileName := filepath.Join(dir, file.Name())
41+
tgtFile, err := os.Create(tgtFileName)
42+
if err != nil {
43+
t.Fatalf("failed to create tgt file %q: %v", tgtFileName, err)
44+
}
45+
defer tgtFile.Close()
46+
if _, err := tgtFile.Write(srcData); err != nil {
47+
t.Fatalf("failed to write to tgt file %q: %v", tgtFileName, err)
48+
}
49+
}
50+
// Rename the go.mod.embed file to go.mod.
51+
if err := os.Rename(filepath.Join(dir, "go.mod.embed"), filepath.Join(dir, "go.mod")); err != nil {
52+
t.Fatalf("failed to rename go.mod.embed: %v", err)
53+
}
54+
55+
tests := []struct {
56+
name string
57+
cmd string
58+
}{
59+
{
60+
name: "Well behaved programs get shut down",
61+
cmd: "go run .",
62+
},
63+
{
64+
name: "Badly behaved programs get shut down",
65+
cmd: "go run . -badly-behaved",
66+
},
67+
}
68+
for _, tt := range tests {
69+
t.Run(tt.name, func(t *testing.T) {
70+
ctx := context.Background()
71+
cmd, err := run.Run(ctx, dir, tt.cmd)
72+
if err != nil {
73+
t.Fatalf("failed to run program: %v", err)
74+
}
75+
76+
time.Sleep(1 * time.Second)
77+
78+
pid := cmd.Process.Pid
79+
80+
if err := run.KillAll(); err != nil {
81+
t.Fatalf("failed to kill all: %v", err)
82+
}
83+
84+
// Check the parent process is no longer running.
85+
if err := cmd.Process.Signal(os.Signal(syscall.Signal(0))); err == nil {
86+
t.Fatalf("process %d is still running", pid)
87+
}
88+
// Check that the child was stopped.
89+
body, err := readResponse("http://localhost:7777")
90+
if err == nil {
91+
t.Fatalf("child process is still running: %s", body)
92+
}
93+
})
94+
}
95+
}
96+
97+
func readResponse(url string) (body string, err error) {
98+
resp, err := http.Get(url)
99+
if err != nil {
100+
return body, err
101+
}
102+
defer resp.Body.Close()
103+
b, err := io.ReadAll(resp.Body)
104+
if err != nil {
105+
return body, err
106+
}
107+
return string(b), nil
108+
}

‎cmd/templ/generatecmd/run/run_unix.go

+35-11
Original file line numberDiff line numberDiff line change
@@ -4,41 +4,63 @@ package run
44

55
import (
66
"context"
7+
"errors"
8+
"fmt"
79
"os"
810
"os/exec"
911
"strings"
1012
"sync"
1113
"syscall"
14+
"time"
1215
)
1316

14-
var m = &sync.Mutex{}
15-
var running = map[string]*exec.Cmd{}
17+
var (
18+
m = &sync.Mutex{}
19+
running = map[string]*exec.Cmd{}
20+
)
1621

1722
func KillAll() (err error) {
1823
m.Lock()
1924
defer m.Unlock()
25+
var errs []error
2026
for _, cmd := range running {
21-
err := syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
22-
if err != nil {
23-
return err
27+
if err := kill(cmd); err != nil {
28+
errs = append(errs, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err))
2429
}
2530
}
2631
running = map[string]*exec.Cmd{}
27-
return
32+
return errors.Join(errs...)
33+
}
34+
35+
func kill(cmd *exec.Cmd) (err error) {
36+
errs := make([]error, 4)
37+
errs[0] = ignoreExited(cmd.Process.Signal(syscall.SIGINT))
38+
errs[1] = ignoreExited(cmd.Process.Signal(syscall.SIGTERM))
39+
errs[2] = ignoreExited(cmd.Wait())
40+
errs[3] = ignoreExited(syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL))
41+
return errors.Join(errs...)
2842
}
2943

30-
func Stop(cmd *exec.Cmd) (err error) {
31-
return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
44+
func ignoreExited(err error) error {
45+
if errors.Is(err, syscall.ESRCH) {
46+
return nil
47+
}
48+
// Ignore *exec.ExitError
49+
if _, ok := err.(*exec.ExitError); ok {
50+
return nil
51+
}
52+
return err
3253
}
3354

3455
func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err error) {
3556
m.Lock()
3657
defer m.Unlock()
3758
cmd, ok := running[input]
3859
if ok {
39-
if err = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL); err != nil {
40-
return cmd, err
60+
if err := kill(cmd); err != nil {
61+
return cmd, fmt.Errorf("failed to kill process %d: %w", cmd.Process.Pid, err)
4162
}
63+
4264
delete(running, input)
4365
}
4466
parts := strings.Fields(input)
@@ -48,7 +70,9 @@ func Run(ctx context.Context, workingDir, input string) (cmd *exec.Cmd, err erro
4870
args = append(args, parts[1:]...)
4971
}
5072

51-
cmd = exec.Command(executable, args...)
73+
cmd = exec.CommandContext(ctx, executable, args...)
74+
// Wait for the process to finish gracefully before termination.
75+
cmd.WaitDelay = time.Second * 3
5276
cmd.Env = os.Environ()
5377
cmd.Dir = workingDir
5478
cmd.Stdout = os.Stdout
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module testprogram
2+
3+
go 1.22.6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"net/http"
7+
"os"
8+
"os/signal"
9+
"syscall"
10+
"time"
11+
)
12+
13+
// This is a test program. It is used only to test the behaviour of the run package.
14+
// The run package is supposed to be able to run and stop programs. Those programs may start
15+
// child processes, which should also be stopped when the parent program is stopped.
16+
17+
// For example, running `go run .` will compile an executable and run it.
18+
19+
// So, this program does nothing. It just waits for a signal to stop.
20+
21+
// In "Well behaved" mode, the program will stop when it receives a signal.
22+
// In "Badly behaved" mode, the program will ignore the signal and continue running.
23+
24+
// The run package should be able to stop the program in both cases.
25+
26+
var badlyBehavedFlag = flag.Bool("badly-behaved", false, "If set, the program will ignore the stop signal and continue running.")
27+
28+
func main() {
29+
flag.Parse()
30+
31+
mode := "Well behaved"
32+
if *badlyBehavedFlag {
33+
mode = "Badly behaved"
34+
}
35+
fmt.Printf("%s process %d started.\n", mode, os.Getpid())
36+
37+
// Start a web server on a known port so that we can check that this process is
38+
// not running, when it's been started as a child process, and we don't know
39+
// its pid.
40+
go func() {
41+
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
42+
fmt.Fprintf(w, "%d", os.Getpid())
43+
})
44+
err := http.ListenAndServe("127.0.0.1:7777", nil)
45+
if err != nil {
46+
fmt.Printf("Error running web server: %v\n", err)
47+
}
48+
}()
49+
50+
sigs := make(chan os.Signal, 1)
51+
if !*badlyBehavedFlag {
52+
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
53+
}
54+
for {
55+
select {
56+
case <-sigs:
57+
fmt.Printf("Process %d received signal. Stopping.\n", os.Getpid())
58+
return
59+
case <-time.After(1 * time.Second):
60+
fmt.Printf("Process %d still running...\n", os.Getpid())
61+
}
62+
}
63+
}

0 commit comments

Comments
 (0)
Please sign in to comment.