ensure processes are killed when client disconnects

This commit is contained in:
Michael Stapelberg 2018-07-22 23:04:18 +02:00
parent 05c84e7002
commit a1fd5f6920

32
ssh.go
View File

@ -1,6 +1,7 @@
package main
import (
"context"
"encoding/binary"
"fmt"
"io"
@ -31,9 +32,12 @@ func handleChannel(newChannel ssh.NewChannel) {
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
go func(channel ssh.Channel, requests <-chan *ssh.Request) {
ctx, canc := context.WithCancel(context.Background())
defer canc()
s := session{channel: channel}
for req := range requests {
if err := s.request(req); err != nil {
if err := s.request(ctx, req); err != nil {
log.Printf("request(%q): %v", req.Type, err)
errmsg := []byte(err.Error())
// Append a trailing newline; the error message is
// displayed as-is by ssh(1).
@ -45,6 +49,7 @@ func handleChannel(newChannel ssh.NewChannel) {
channel.Close()
}
}
log.Printf("requests exhausted")
}(channel, requests)
}
@ -93,7 +98,7 @@ func stringFromPayload(payload []byte, offset int) (string, int, error) {
return string(name), offset + 4 + int(namelen), nil
}
func (s *session) request(req *ssh.Request) error {
func (s *session) request(ctx context.Context, req *ssh.Request) error {
switch req.Type {
case "pty-req":
var err error
@ -152,9 +157,9 @@ func (s *session) request(req *ssh.Request) error {
var cmd *exec.Cmd
if _, err := exec.LookPath("sh"); err == nil {
cmd = exec.Command("sh", "-c", string(req.Payload[4:]))
cmd = exec.CommandContext(ctx, "sh", "-c", string(req.Payload[4:]))
} else {
cmd = exec.Command(cmdline[0], cmdline[1:]...)
cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
}
log.Printf("Starting cmd %q", cmd.Args)
cmd.Env = expandPath(s.env)
@ -188,15 +193,18 @@ func (s *session) request(req *ssh.Request) error {
stdin.Close()
}()
if err := cmd.Wait(); err != nil {
return err
}
go func() {
// TODO: correctly pass on the exit code, currently it is always 255
if err := cmd.Wait(); err != nil {
log.Printf("err: %v", err)
}
// See https://tools.ietf.org/html/rfc4254#section-6.10
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
return err
}
s.channel.Close()
// See https://tools.ietf.org/html/rfc4254#section-6.10
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
log.Printf("err2: %v", err)
}
s.channel.Close()
}()
return nil
}