ensure processes are killed when client disconnects
This commit is contained in:
parent
05c84e7002
commit
a1fd5f6920
32
ssh.go
32
ssh.go
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -31,9 +32,12 @@ func handleChannel(newChannel ssh.NewChannel) {
|
||||
|
||||
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
|
||||
go func(channel ssh.Channel, requests <-chan *ssh.Request) {
|
||||
ctx, canc := context.WithCancel(context.Background())
|
||||
defer canc()
|
||||
s := session{channel: channel}
|
||||
for req := range requests {
|
||||
if err := s.request(req); err != nil {
|
||||
if err := s.request(ctx, req); err != nil {
|
||||
log.Printf("request(%q): %v", req.Type, err)
|
||||
errmsg := []byte(err.Error())
|
||||
// Append a trailing newline; the error message is
|
||||
// displayed as-is by ssh(1).
|
||||
@ -45,6 +49,7 @@ func handleChannel(newChannel ssh.NewChannel) {
|
||||
channel.Close()
|
||||
}
|
||||
}
|
||||
log.Printf("requests exhausted")
|
||||
}(channel, requests)
|
||||
}
|
||||
|
||||
@ -93,7 +98,7 @@ func stringFromPayload(payload []byte, offset int) (string, int, error) {
|
||||
return string(name), offset + 4 + int(namelen), nil
|
||||
}
|
||||
|
||||
func (s *session) request(req *ssh.Request) error {
|
||||
func (s *session) request(ctx context.Context, req *ssh.Request) error {
|
||||
switch req.Type {
|
||||
case "pty-req":
|
||||
var err error
|
||||
@ -152,9 +157,9 @@ func (s *session) request(req *ssh.Request) error {
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if _, err := exec.LookPath("sh"); err == nil {
|
||||
cmd = exec.Command("sh", "-c", string(req.Payload[4:]))
|
||||
cmd = exec.CommandContext(ctx, "sh", "-c", string(req.Payload[4:]))
|
||||
} else {
|
||||
cmd = exec.Command(cmdline[0], cmdline[1:]...)
|
||||
cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
|
||||
}
|
||||
log.Printf("Starting cmd %q", cmd.Args)
|
||||
cmd.Env = expandPath(s.env)
|
||||
@ -188,15 +193,18 @@ func (s *session) request(req *ssh.Request) error {
|
||||
stdin.Close()
|
||||
}()
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
// TODO: correctly pass on the exit code, currently it is always 255
|
||||
if err := cmd.Wait(); err != nil {
|
||||
log.Printf("err: %v", err)
|
||||
}
|
||||
|
||||
// See https://tools.ietf.org/html/rfc4254#section-6.10
|
||||
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
|
||||
return err
|
||||
}
|
||||
s.channel.Close()
|
||||
// See https://tools.ietf.org/html/rfc4254#section-6.10
|
||||
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
|
||||
log.Printf("err2: %v", err)
|
||||
}
|
||||
s.channel.Close()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user