fix subsystem invocation: send exit code afterwards

This fixes scp(1) with OpenSSH ≥ 9.
This commit is contained in:
Michael Stapelberg 2022-04-17 15:23:09 +02:00
parent 097a6f87d6
commit 7dbbe9b4b3

36
ssh.go
View File

@ -2,7 +2,6 @@ package main
import (
"context"
"encoding/binary"
"fmt"
"io"
"log"
@ -210,6 +209,11 @@ type subsystem struct {
SubsystemName string
}
// exitStatus is a message for returning exit status as specified in RFC4254, Section 6.10
type exitStatus struct {
Status uint32
}
func findShell() string {
if path, err := exec.LookPath("sh"); err == nil {
return path
@ -270,23 +274,27 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error {
log.Printf("starting SFTP subsystem")
req.Reply(true, nil)
srv, err := sftp.NewServer(s.channel, sftp.WithDebug(os.Stderr))
if err != nil {
return err
}
go func() {
err := srv.Serve()
if err != nil {
log.Printf("(sftp.Server).Serve(): %v", err)
if err == io.EOF {
srv.Close()
log.Printf("sftp client exited session")
}
exitCode := uint32(0)
if err := srv.Serve(); err != nil {
log.Printf("(sftp.Server).Serve(): %v", err)
if err == io.EOF {
defer srv.Close()
log.Printf("sftp client exited session")
} else {
exitCode = 1
}
}()
}
req.Reply(true, nil)
// See https://tools.ietf.org/html/rfc4254#section-6.10
_, err = s.channel.SendRequest("exit-status", false /* wantReply */, ssh.Marshal(exitStatus{exitCode}))
return err
case "shell":
req.Payload = []byte("\x00\x00\x00\x02sh")
@ -353,13 +361,13 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error {
if err := cmd.Wait(); err != nil {
log.Printf("err: %v", err)
}
status := make([]byte, 4)
var status exitStatus
if ws, ok := cmd.ProcessState.Sys().(syscall.WaitStatus); ok {
binary.BigEndian.PutUint32(status, uint32(ws.ExitStatus()))
status.Status = uint32(ws.ExitStatus())
}
// See https://tools.ietf.org/html/rfc4254#section-6.10
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, status); err != nil {
if _, err := s.channel.SendRequest("exit-status", false /* wantReply */, ssh.Marshal(status)); err != nil {
log.Printf("err2: %v", err)
}
s.channel.Close()