ensure processes are killed when client disconnects
This commit is contained in:
parent
05c84e7002
commit
a1fd5f6920
32
ssh.go
32
ssh.go
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@ -31,9 +32,12 @@ func handleChannel(newChannel ssh.NewChannel) {
|
|||||||
|
|
||||||
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
|
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
|
||||||
go func(channel ssh.Channel, requests <-chan *ssh.Request) {
|
go func(channel ssh.Channel, requests <-chan *ssh.Request) {
|
||||||
|
ctx, canc := context.WithCancel(context.Background())
|
||||||
|
defer canc()
|
||||||
s := session{channel: channel}
|
s := session{channel: channel}
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
if err := s.request(req); err != nil {
|
if err := s.request(ctx, req); err != nil {
|
||||||
|
log.Printf("request(%q): %v", req.Type, err)
|
||||||
errmsg := []byte(err.Error())
|
errmsg := []byte(err.Error())
|
||||||
// Append a trailing newline; the error message is
|
// Append a trailing newline; the error message is
|
||||||
// displayed as-is by ssh(1).
|
// displayed as-is by ssh(1).
|
||||||
@ -45,6 +49,7 @@ func handleChannel(newChannel ssh.NewChannel) {
|
|||||||
channel.Close()
|
channel.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
log.Printf("requests exhausted")
|
||||||
}(channel, requests)
|
}(channel, requests)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,7 +98,7 @@ func stringFromPayload(payload []byte, offset int) (string, int, error) {
|
|||||||
return string(name), offset + 4 + int(namelen), nil
|
return string(name), offset + 4 + int(namelen), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) request(req *ssh.Request) error {
|
func (s *session) request(ctx context.Context, req *ssh.Request) error {
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case "pty-req":
|
case "pty-req":
|
||||||
var err error
|
var err error
|
||||||
@ -152,9 +157,9 @@ func (s *session) request(req *ssh.Request) error {
|
|||||||
|
|
||||||
var cmd *exec.Cmd
|
var cmd *exec.Cmd
|
||||||
if _, err := exec.LookPath("sh"); err == nil {
|
if _, err := exec.LookPath("sh"); err == nil {
|
||||||
cmd = exec.Command("sh", "-c", string(req.Payload[4:]))
|
cmd = exec.CommandContext(ctx, "sh", "-c", string(req.Payload[4:]))
|
||||||
} else {
|
} else {
|
||||||
cmd = exec.Command(cmdline[0], cmdline[1:]...)
|
cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
|
||||||
}
|
}
|
||||||
log.Printf("Starting cmd %q", cmd.Args)
|
log.Printf("Starting cmd %q", cmd.Args)
|
||||||
cmd.Env = expandPath(s.env)
|
cmd.Env = expandPath(s.env)
|
||||||
@ -188,15 +193,18 @@ func (s *session) request(req *ssh.Request) error {
|
|||||||
stdin.Close()
|
stdin.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := cmd.Wait(); err != nil {
|
go func() {
|
||||||
return err
|
// TODO: correctly pass on the exit code, currently it is always 255
|
||||||
}
|
if err := cmd.Wait(); err != nil {
|
||||||
|
log.Printf("err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// See https://tools.ietf.org/html/rfc4254#section-6.10
|
// See https://tools.ietf.org/html/rfc4254#section-6.10
|
||||||
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
|
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, []byte("\x00\x00\x00\x00")); err != nil {
|
||||||
return err
|
log.Printf("err2: %v", err)
|
||||||
}
|
}
|
||||||
s.channel.Close()
|
s.channel.Close()
|
||||||
|
}()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user