diff --git a/ssh.go b/ssh.go index 6b19f61..1edcf0a 100644 --- a/ssh.go +++ b/ssh.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/binary" "fmt" "io" @@ -31,9 +32,12 @@ func handleChannel(newChannel ssh.NewChannel) { // Sessions have out-of-band requests such as "shell", "pty-req" and "env" go func(channel ssh.Channel, requests <-chan *ssh.Request) { + ctx, canc := context.WithCancel(context.Background()) + defer canc() s := session{channel: channel} for req := range requests { - if err := s.request(req); err != nil { + if err := s.request(ctx, req); err != nil { + log.Printf("request(%q): %v", req.Type, err) errmsg := []byte(err.Error()) // Append a trailing newline; the error message is // displayed as-is by ssh(1). @@ -45,6 +49,7 @@ func handleChannel(newChannel ssh.NewChannel) { channel.Close() } } + log.Printf("requests exhausted") }(channel, requests) } @@ -93,7 +98,7 @@ func stringFromPayload(payload []byte, offset int) (string, int, error) { return string(name), offset + 4 + int(namelen), nil } -func (s *session) request(req *ssh.Request) error { +func (s *session) request(ctx context.Context, req *ssh.Request) error { switch req.Type { case "pty-req": var err error @@ -152,9 +157,9 @@ func (s *session) request(req *ssh.Request) error { var cmd *exec.Cmd if _, err := exec.LookPath("sh"); err == nil { - cmd = exec.Command("sh", "-c", string(req.Payload[4:])) + cmd = exec.CommandContext(ctx, "sh", "-c", string(req.Payload[4:])) } else { - cmd = exec.Command(cmdline[0], cmdline[1:]...) + cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...) } log.Printf("Starting cmd %q", cmd.Args) cmd.Env = expandPath(s.env) @@ -188,15 +193,18 @@ func (s *session) request(req *ssh.Request) error { stdin.Close() }() - if err := cmd.Wait(); err != nil { - return err - } + go func() { + // TODO: correctly pass on the exit code, currently it is always 255 + if err := cmd.Wait(); err != nil { + log.Printf("err: %v", err) + } - // See https://tools.ietf.org/html/rfc4254#section-6.10 - if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil { - return err - } - s.channel.Close() + // See https://tools.ietf.org/html/rfc4254#section-6.10 + if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil { + log.Printf("err2: %v", err) + } + s.channel.Close() + }() return nil }