ensure processes are killed when client disconnects

This commit is contained in:
Michael Stapelberg 2018-07-22 23:04:18 +02:00
parent 05c84e7002
commit a1fd5f6920

20
ssh.go
View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
@ -31,9 +32,12 @@ func handleChannel(newChannel ssh.NewChannel) {
// Sessions have out-of-band requests such as "shell", "pty-req" and "env" // Sessions have out-of-band requests such as "shell", "pty-req" and "env"
go func(channel ssh.Channel, requests <-chan *ssh.Request) { go func(channel ssh.Channel, requests <-chan *ssh.Request) {
ctx, canc := context.WithCancel(context.Background())
defer canc()
s := session{channel: channel} s := session{channel: channel}
for req := range requests { for req := range requests {
if err := s.request(req); err != nil { if err := s.request(ctx, req); err != nil {
log.Printf("request(%q): %v", req.Type, err)
errmsg := []byte(err.Error()) errmsg := []byte(err.Error())
// Append a trailing newline; the error message is // Append a trailing newline; the error message is
// displayed as-is by ssh(1). // displayed as-is by ssh(1).
@ -45,6 +49,7 @@ func handleChannel(newChannel ssh.NewChannel) {
channel.Close() channel.Close()
} }
} }
log.Printf("requests exhausted")
}(channel, requests) }(channel, requests)
} }
@ -93,7 +98,7 @@ func stringFromPayload(payload []byte, offset int) (string, int, error) {
return string(name), offset + 4 + int(namelen), nil return string(name), offset + 4 + int(namelen), nil
} }
func (s *session) request(req *ssh.Request) error { func (s *session) request(ctx context.Context, req *ssh.Request) error {
switch req.Type { switch req.Type {
case "pty-req": case "pty-req":
var err error var err error
@ -152,9 +157,9 @@ func (s *session) request(req *ssh.Request) error {
var cmd *exec.Cmd var cmd *exec.Cmd
if _, err := exec.LookPath("sh"); err == nil { if _, err := exec.LookPath("sh"); err == nil {
cmd = exec.Command("sh", "-c", string(req.Payload[4:])) cmd = exec.CommandContext(ctx, "sh", "-c", string(req.Payload[4:]))
} else { } else {
cmd = exec.Command(cmdline[0], cmdline[1:]...) cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
} }
log.Printf("Starting cmd %q", cmd.Args) log.Printf("Starting cmd %q", cmd.Args)
cmd.Env = expandPath(s.env) cmd.Env = expandPath(s.env)
@ -188,15 +193,18 @@ func (s *session) request(req *ssh.Request) error {
stdin.Close() stdin.Close()
}() }()
go func() {
// TODO: correctly pass on the exit code, currently it is always 255
if err := cmd.Wait(); err != nil { if err := cmd.Wait(); err != nil {
return err log.Printf("err: %v", err)
} }
// See https://tools.ietf.org/html/rfc4254#section-6.10 // See https://tools.ietf.org/html/rfc4254#section-6.10
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil { if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
return err log.Printf("err2: %v", err)
} }
s.channel.Close() s.channel.Close()
}()
return nil return nil
} }