breakglass/scp.go
Michael Stapelberg 48c5124500 unpack tar files copied via sftp subsystem, too (not just older scp)
For compatibility with OpenSSH ≥ 9
2022-04-17 15:32:45 +02:00

134 lines
2.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"archive/tar"
"flag"
"fmt"
"io"
"log"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/google/renameio/v2"
"golang.org/x/crypto/ssh"
)
// countingWriter is a write sink that stores nothing: its value is the
// running total of bytes that have been passed to Write.
type countingWriter int64

// Write adds len(p) to the running total and reports all of p as written.
// It never returns an error.
func (cw *countingWriter) Write(p []byte) (int, error) {
	written := len(p)
	*cw = *cw + countingWriter(written)
	return written, nil
}
// scpSink implements the receiving ("sink", scp -t) side of an scp transfer
// on channel.
//
// cmdline[0] is skipped and the remaining elements are parsed as scp flags;
// -t must be set. After signalling readiness with a zero byte, scpSink reads
// control messages from channel and acknowledges each one. For every 'C'
// (file) message, the payload that follows is unpacked as a tar archive via
// unpackTar. When channel reports io.EOF, an exit-status of 0 is sent, the
// channel is closed and req is answered positively.
//
// Any flag parsing, protocol, I/O or unpacking error is returned as-is.
func scpSink(channel ssh.Channel, req *ssh.Request, cmdline []string) error {
	if len(cmdline) == 0 {
		// Guard the cmdline[1:] slice expression below, which would panic.
		return fmt.Errorf("empty command line")
	}
	scpFlags := flag.NewFlagSet("scp", flag.ContinueOnError)
	sink := scpFlags.Bool("t", false, "sink (to)")
	if err := scpFlags.Parse(cmdline[1:]); err != nil {
		return err
	}
	if !*sink {
		return fmt.Errorf("expected -t")
	}

	// Tell the remote end we're ready to receive data.
	if _, err := channel.Write([]byte{0x00}); err != nil {
		return err
	}

	// NOTE(review): a control message is assumed to arrive in a single Read
	// of at most 1024 bytes — confirm against the scp clients in use.
	buf := make([]byte, 1024)
	for {
		n, err := channel.Read(buf)
		if err == io.EOF {
			break
		}
		if err != nil {
			return err
		}
		if n == 0 {
			// Nothing was read; indexing msg[0] below would panic.
			continue
		}
		msg := buf[:n]

		// Acknowledge receipt of the control message
		if _, err := channel.Write([]byte{0x00}); err != nil {
			return err
		}

		if msg[0] != 'C' {
			continue // only file ('C') messages carry a payload
		}

		// The message has three fields; parts[1] is the payload size.
		// SplitN keeps a last field that itself contains spaces in one
		// piece instead of rejecting the message.
		msgstr := strings.TrimSpace(string(msg))
		parts := strings.SplitN(msgstr, " ", 3)
		if got, want := len(parts), 3; got != want {
			return fmt.Errorf("invalid number of space-separated tokens in control message %q: got %d, want %d", msgstr, got, want)
		}
		size, err := strconv.ParseInt(parts[1], 0, 64)
		if err != nil {
			return err
		}

		// Retrieve file contents, counting how many payload bytes the tar
		// reader actually consumed.
		var cw countingWriter
		if err := unpackTar(io.TeeReader(channel, &cw)); err != nil {
			return err
		}

		// Drain whatever part of the announced payload the tar reader left
		// unread. Streaming into io.Discard reads exactly rest bytes (a
		// single Read may return fewer) and avoids allocating a buffer
		// whose size is dictated by the remote end.
		if rest := size - int64(cw); rest > 0 {
			if _, err := io.CopyN(io.Discard, channel, rest); err != nil {
				return err
			}
		}

		// Read status byte after transfer
		var status [1]byte
		if _, err := io.ReadFull(channel, status[:]); err != nil {
			return err
		}

		// Acknowledge file transfer
		if _, err := channel.Write([]byte{0x00}); err != nil {
			return err
		}
	}

	// Four zero bytes: the exit-status payload, reporting status 0.
	exitStatus := make([]byte, 4)
	if _, err := channel.SendRequest("exit-status", false, exitStatus); err != nil {
		return err
	}
	// The transfer is complete at this point; failures to close the channel
	// or to answer the request are not reported to the caller.
	channel.Close()
	req.Reply(true, nil)
	return nil
}
// unpackTar reads a tar archive from r and extracts its entries to the paths
// given by their header names. Parent directories are created with mode 0700
// as needed. Entries whose name ends in "/" are treated as directories and
// get no file. Every other entry is written through a renameio pending file
// that atomically replaces the destination, using the permission bits from
// the tar header.
//
// It returns nil once the end of the archive is reached, or the first error
// encountered.
func unpackTar(r io.Reader) error {
	tr := tar.NewReader(r)
	for {
		h, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		log.Printf("extracting %q", h.Name)
		if err := os.MkdirAll(filepath.Dir(h.Name), 0700); err != nil {
			return err
		}
		if strings.HasSuffix(h.Name, "/") {
			continue // directory, don't try to create a file for it
		}
		if err := extractTarFile(h, tr); err != nil {
			return err
		}
	}
}

// extractTarFile writes the contents of the current tar entry, read from src,
// to h.Name via a renameio pending file that atomically replaces the target.
func extractTarFile(h *tar.Header, src io.Reader) error {
	mode := h.FileInfo().Mode() & os.ModePerm
	out, err := renameio.NewPendingFile(h.Name, renameio.WithStaticPermissions(mode))
	if err != nil {
		return err
	}
	// Remove the temporary file if copying or replacing fails; Cleanup is a
	// no-op once CloseAtomicallyReplace has succeeded.
	defer out.Cleanup()

	if _, err := io.Copy(out, src); err != nil {
		return err
	}
	return out.CloseAtomicallyReplace()
}