diff --git a/cmd/root.go b/cmd/root.go index 2900ac5..50a2155 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -391,6 +391,8 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str } if ok, _ := cmd.Flags().GetBool("bug-report"); ok { + ctx, cancel := common.EarlyCancelContext(ctx) + defer cancel() return bugReport(ctx, cmd.Version) } if ok, _ := cmd.Flags().GetBool("man-page"); ok { @@ -430,6 +432,8 @@ func newRunCommand(ctx context.Context, input *Input) func(*cobra.Command, []str _ = readEnvsEx(input.Secretfile(), secrets, true) if _, hasGitHubToken := secrets["GITHUB_TOKEN"]; !hasGitHubToken { + ctx, cancel := common.EarlyCancelContext(ctx) + defer cancel() secrets["GITHUB_TOKEN"], _ = gh.GetToken(ctx, "") } @@ -772,10 +776,13 @@ func watchAndRun(ctx context.Context, fn common.Executor) error { return err } + earlyCancelCtx, cancel := common.EarlyCancelContext(ctx) + defer cancel() + for folderWatcher.IsRunning() { log.Debugf("Watching %s for changes", dir) select { - case <-ctx.Done(): + case <-earlyCancelCtx.Done(): return nil case changes := <-folderWatcher.ChangeDetails(): log.Debugf("%s", changes.String()) diff --git a/main.go b/main.go index 37b0fec..1567b96 100644 --- a/main.go +++ b/main.go @@ -1,36 +1,18 @@ package main import ( - "context" _ "embed" - "os" - "os/signal" - "syscall" "github.com/nektos/act/cmd" + "github.com/nektos/act/pkg/common" ) //go:embed VERSION var version string func main() { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) - - // trap Ctrl+C and call cancel on the context - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - defer func() { - signal.Stop(c) - cancel() - }() - go func() { - select { - case <-c: - cancel() - case <-ctx.Done(): - } - }() + ctx, cancel := common.CreateGracefulJobCancellationContext() + defer cancel() // run the command cmd.Execute(ctx, version) diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..6e2f4bf --- /dev/null +++ b/main_test.go @@ -0,0 +1,11 @@ +package main + +import ( + "os" + "testing" +) + +func TestMain(_ *testing.T) { + os.Args = []string{"act", "--help"} + main() +} diff --git a/pkg/common/context.go b/pkg/common/context.go new file mode 100644 index 0000000..0a53489 --- /dev/null +++ b/pkg/common/context.go @@ -0,0 +1,45 @@ +package common + +import ( + "context" + "os" + "os/signal" + "syscall" +) + +func createGracefulJobCancellationContext() (context.Context, func(), chan os.Signal) { + ctx := context.Background() + ctx, forceCancel := context.WithCancel(ctx) + cancelCtx, cancel := context.WithCancel(ctx) + ctx = WithJobCancelContext(ctx, cancelCtx) + + // trap Ctrl+C and call cancel on the context + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + go func() { + select { + case sig := <-c: + if sig == os.Interrupt { + cancel() + select { + case <-c: + forceCancel() + case <-ctx.Done(): + } + } else { + forceCancel() + } + case <-ctx.Done(): + } + }() + return ctx, func() { + signal.Stop(c) + forceCancel() + cancel() + }, c +} + +func CreateGracefulJobCancellationContext() (context.Context, func()) { + ctx, cancel, _ := createGracefulJobCancellationContext() + return ctx, cancel +} diff --git a/pkg/common/context_test.go b/pkg/common/context_test.go new file mode 100644 index 0000000..e821f83 --- /dev/null +++ b/pkg/common/context_test.go @@ -0,0 +1,98 @@ +package common + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestGracefulJobCancellationViaSigint(t *testing.T) { + ctx, cancel, channel := createGracefulJobCancellationContext() + defer cancel() + assert.NotNil(t, ctx) + assert.NotNil(t, cancel) + assert.NotNil(t, channel) + cancelCtx := JobCancelContext(ctx) + assert.NotNil(t, cancelCtx) + assert.NoError(t, ctx.Err()) + assert.NoError(t, cancelCtx.Err()) + channel <- os.Interrupt + select { + case <-time.After(1 * time.Second): + t.Fatal("context not canceled") + case <-cancelCtx.Done(): + case <-ctx.Done(): + } + if assert.Error(t, cancelCtx.Err(), "context canceled") { + assert.Equal(t, context.Canceled, cancelCtx.Err()) + } + assert.NoError(t, ctx.Err()) + channel <- os.Interrupt + select { + case <-time.After(1 * time.Second): + t.Fatal("context not canceled") + case <-ctx.Done(): + } + if assert.Error(t, ctx.Err(), "context canceled") { + assert.Equal(t, context.Canceled, ctx.Err()) + } +} + +func TestForceCancellationViaSigterm(t *testing.T) { + ctx, cancel, channel := createGracefulJobCancellationContext() + defer cancel() + assert.NotNil(t, ctx) + assert.NotNil(t, cancel) + assert.NotNil(t, channel) + cancelCtx := JobCancelContext(ctx) + assert.NotNil(t, cancelCtx) + assert.NoError(t, ctx.Err()) + assert.NoError(t, cancelCtx.Err()) + channel <- syscall.SIGTERM + select { + case <-time.After(1 * time.Second): + t.Fatal("context not canceled") + case <-cancelCtx.Done(): + } + select { + case <-time.After(1 * time.Second): + t.Fatal("context not canceled") + case <-ctx.Done(): + } + if assert.Error(t, ctx.Err(), "context canceled") { + assert.Equal(t, context.Canceled, ctx.Err()) + } + if assert.Error(t, cancelCtx.Err(), "context canceled") { + assert.Equal(t, context.Canceled, cancelCtx.Err()) + } +} + +func TestCreateGracefulJobCancellationContext(t *testing.T) { + ctx, cancel := CreateGracefulJobCancellationContext() + defer cancel() + assert.NotNil(t, ctx) + assert.NotNil(t, cancel) + cancelCtx := JobCancelContext(ctx) + assert.NotNil(t, cancelCtx) + assert.NoError(t, cancelCtx.Err()) +} + +func TestCreateGracefulJobCancellationContextCancelFunc(t *testing.T) { + ctx, cancel := CreateGracefulJobCancellationContext() + assert.NotNil(t, ctx) + assert.NotNil(t, cancel) + cancelCtx := JobCancelContext(ctx) + assert.NotNil(t, cancelCtx) + assert.NoError(t, cancelCtx.Err()) + cancel() + if assert.Error(t, ctx.Err(), "context canceled") { + assert.Equal(t, context.Canceled, ctx.Err()) + } + if assert.Error(t, cancelCtx.Err(), "context canceled") { + assert.Equal(t, context.Canceled, cancelCtx.Err()) + } +} diff --git a/pkg/common/job_error.go b/pkg/common/job_error.go index 334c6ca..3eb2128 100644 --- a/pkg/common/job_error.go +++ b/pkg/common/job_error.go @@ -8,6 +8,10 @@ type jobErrorContextKey string const jobErrorContextKeyVal = jobErrorContextKey("job.error") +type jobCancelCtx string + +const JobCancelCtxVal = jobCancelCtx("job.cancel") + // JobError returns the job error for current context if any func JobError(ctx context.Context) error { val := ctx.Value(jobErrorContextKeyVal) @@ -28,3 +32,35 @@ func WithJobErrorContainer(ctx context.Context) context.Context { container := map[string]error{} return context.WithValue(ctx, jobErrorContextKeyVal, container) } + +func WithJobCancelContext(ctx context.Context, cancelContext context.Context) context.Context { + return context.WithValue(ctx, JobCancelCtxVal, cancelContext) +} + +func JobCancelContext(ctx context.Context) context.Context { + val := ctx.Value(JobCancelCtxVal) + if val != nil { + if container, ok := val.(context.Context); ok { + return container + } + } + return nil +} + +// EarlyCancelContext returns a new context based on ctx that is canceled when the first of the provided contexts is canceled. +func EarlyCancelContext(ctx context.Context) (context.Context, context.CancelFunc) { + val := JobCancelContext(ctx) + if val != nil { + context, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + select { + case <-context.Done(): + case <-ctx.Done(): + case <-val.Done(): + } + }() + return context, cancel + } + return ctx, func() {} +} diff --git a/pkg/runner/run_context.go b/pkg/runner/run_context.go index c609175..b6b6a48 100644 --- a/pkg/runner/run_context.go +++ b/pkg/runner/run_context.go @@ -51,6 +51,7 @@ type RunContext struct { Masks []string cleanUpJobContainer common.Executor caller *caller // job calling this RunContext (reusable workflows) + Cancelled bool nodeToolFullPath string } @@ -435,6 +436,8 @@ func (rc *RunContext) execJobContainer(cmd []string, env map[string]string, user func (rc *RunContext) InitializeNodeTool() common.Executor { return func(ctx context.Context) error { + ctx, cancel := common.EarlyCancelContext(ctx) + defer cancel() rc.GetNodeToolFullPath(ctx) return nil } @@ -651,6 +654,8 @@ func (rc *RunContext) interpolateOutputs() common.Executor { func (rc *RunContext) startContainer() common.Executor { return func(ctx context.Context) error { + ctx, cancel := common.EarlyCancelContext(ctx) + defer cancel() if rc.IsHostEnv(ctx) { return rc.startHostEnvironment()(ctx) } @@ -845,10 +850,14 @@ func trimToLen(s string, l int) string { func (rc *RunContext) getJobContext() *model.JobContext { jobStatus := "success" - for _, stepStatus := range rc.StepResults { - if stepStatus.Conclusion == model.StepStatusFailure { - jobStatus = "failure" - break + if rc.Cancelled { + jobStatus = "cancelled" + } else { + for _, stepStatus := range rc.StepResults { + if stepStatus.Conclusion == model.StepStatusFailure { + jobStatus = "failure" + break + } } } return &model.JobContext{ diff --git a/pkg/runner/step.go b/pkg/runner/step.go index 5ed2108..a70dacc 100644 --- a/pkg/runner/step.go +++ b/pkg/runner/step.go @@ -85,6 +85,9 @@ func runStepExecutor(step step, stage stepStage, executor common.Executor) commo return err } + cctx := common.JobCancelContext(ctx) + rc.Cancelled = cctx != nil && cctx.Err() != nil + runStep, err := isStepEnabled(ctx, ifExpression, step, stage) if err != nil { stepResult.Conclusion = model.StepStatusFailure @@ -140,10 +143,14 @@ func runStepExecutor(step step, stage stepStage, executor common.Executor) commo Mode: 0o666, })(ctx) - timeoutctx, cancelTimeOut := evaluateStepTimeout(ctx, rc.ExprEval, stepModel) + stepCtx, cancelStepCtx := context.WithCancel(ctx) + defer cancelStepCtx() + var cancelTimeOut context.CancelFunc + stepCtx, cancelTimeOut = evaluateStepTimeout(stepCtx, rc.ExprEval, stepModel) defer cancelTimeOut() + monitorJobCancellation(ctx, stepCtx, cctx, rc, logger, ifExpression, step, stage, cancelStepCtx) startTime := time.Now() - err = executor(timeoutctx) + err = executor(stepCtx) executionTime := time.Since(startTime) if err == nil { @@ -192,6 +199,24 @@ func runStepExecutor(step step, stage stepStage, executor common.Executor) commo } } +func monitorJobCancellation(ctx context.Context, stepCtx context.Context, jobCancellationCtx context.Context, rc *RunContext, logger logrus.FieldLogger, ifExpression string, step step, stage stepStage, cancelStepCtx context.CancelFunc) { + if !rc.Cancelled && jobCancellationCtx != nil { + go func() { + select { + case <-jobCancellationCtx.Done(): + rc.Cancelled = true + logger.Infof("Reevaluate condition %v due to cancellation", ifExpression) + keepStepRunning, err := isStepEnabled(ctx, ifExpression, step, stage) + logger.Infof("Result condition keepStepRunning=%v", keepStepRunning) + if !keepStepRunning || err != nil { + cancelStepCtx() + } + case <-stepCtx.Done(): + } + }() + } +} + func evaluateStepTimeout(ctx context.Context, exprEval ExpressionEvaluator, stepModel *model.Step) (context.Context, context.CancelFunc) { timeout := exprEval.Interpolate(ctx, stepModel.TimeoutMinutes) if timeout != "" {