diff --git a/breakglass.go b/breakglass.go index c612bc0..c7681f7 100644 --- a/breakglass.go +++ b/breakglass.go @@ -10,6 +10,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "flag" "fmt" "io/ioutil" @@ -17,6 +18,7 @@ import ( "net" "net/http" "os" + "path" "strings" "syscall" @@ -30,6 +32,10 @@ var ( "/perm/breakglass.authorized_keys", "path to an OpenSSH authorized_keys file; if the value is 'ec2', fetch the SSH key(s) from the AWS IMDSv2 metadata") + authorizedUserCAPath = flag.String("authorized_ca", + "/perm/breakglass.authorized_user_ca", + "path to an OpenSSH TrustedUserCAKeys file; note the certificate must list ':gokrazy:' as a valid principal") + hostKeyPath = flag.String("host_key", "/perm/breakglass.host_key", "path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)") @@ -45,6 +51,9 @@ var ( forwarding = flag.String("forward", "", "allow port forwarding. Use `loopback` for loopback interfaces and `private-network` for private networks") + + home = "/perm/home" + shell = "" ) func loadAuthorizedKeys(path string) (map[string]bool, error) { @@ -80,6 +89,19 @@ func loadAuthorizedKeys(path string) (map[string]bool, error) { return result, nil } +func loadPasswd(passwd string) { + b, err := os.ReadFile(passwd) + if err != nil { + return + } + fields := bytes.SplitN(bytes.SplitN(b, []byte("\n"), 2)[0], []byte(":"), 7) + if len(fields) != 7 { + return + } + home = path.Clean(string(fields[5])) + shell = path.Clean(string(fields[6])) +} + func loadHostKey(path string) (ssh.Signer, error) { b, err := ioutil.ReadFile(path) if err != nil { @@ -177,7 +199,7 @@ func initMOTD() error { return err } - motd = fmt.Sprintf(` __ + motd = fmt.Sprintf(` __ .-----.-----| |--.----.---.-.-----.--.--. | _ | _ | <| _| _ |-- __| | | |___ |_____|__|__|__| |___._|_____|___ | @@ -197,6 +219,8 @@ func main() { gokrazy.DontStartOnBoot() + loadPasswd("/etc/passwd") + authorizedKeys, err := loadAuthorizedKeys(*authorizedKeysPath) if err != nil { if os.IsNotExist(err) { @@ -205,19 +229,59 @@ func main() { log.Fatal(err) } + authorizedUserCertificateCA, err := loadAuthorizedKeys(strings.TrimPrefix(*authorizedUserCAPath, "ec2")) + if err != nil { + if os.IsNotExist(err) { + log.Printf("TrustedUserCAKeys not loaded") + } + } + if err := initMOTD(); err != nil { log.Print(err) } - - config := &ssh.ServerConfig{ - PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + certChecker := ssh.CertChecker{ + IsUserAuthority: func(auth ssh.PublicKey) bool { + return authorizedUserCertificateCA[string(auth.Marshal())] + }, + UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { if authorizedKeys[string(pubKey.Marshal())] { log.Printf("user %q successfully authorized from remote addr %s", conn.User(), conn.RemoteAddr()) - return nil, nil + return &ssh.Permissions{map[string]string{}, map[string]string{}}, nil } return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath) }, } + config := &ssh.ServerConfig{ + PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { + cert, ok := key.(*ssh.Certificate) + if !ok { + if certChecker.UserKeyFallback != nil { + return certChecker.UserKeyFallback(conn, key) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != ssh.UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + if !certChecker.IsUserAuthority(cert.SignatureKey) { + return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + if err := certChecker.CheckCert(":gokrazy:", cert); err != nil { + return nil, err + } + if cert.Permissions.CriticalOptions == nil { + cert.Permissions.CriticalOptions = map[string]string{} + } + if cert.Permissions.Extensions == nil { + cert.Permissions.Extensions = map[string]string{} + } + + return &cert.Permissions, nil + + }, + } signer, err := loadHostKey(*hostKeyPath) if err != nil { @@ -267,7 +331,7 @@ func main() { } go func(conn net.Conn) { - _, chans, reqs, err := ssh.NewServerConn(conn, config) + c, chans, reqs, err := ssh.NewServerConn(conn, config) if err != nil { log.Printf("handshake: %v", err) return @@ -277,7 +341,7 @@ func main() { go ssh.DiscardRequests(reqs) for newChannel := range chans { - handleChannel(newChannel) + handleChannel(newChannel, c) } }(conn) } diff --git a/ssh.go b/ssh.go index fdaa859..63f531a 100644 --- a/ssh.go +++ b/ssh.go @@ -2,12 +2,14 @@ package main import ( "context" + "errors" "fmt" "io" "log" "net" "os" "os/exec" + "path" "strconv" "strings" "sync" @@ -21,11 +23,15 @@ import ( "golang.org/x/crypto/ssh" ) -func handleChannel(newChan ssh.NewChannel) { +func handleChannel(newChan ssh.NewChannel, conn *ssh.ServerConn) { switch t := newChan.ChannelType(); t { case "session": - handleSession(newChan) + handleSession(newChan, conn) case "direct-tcpip": + if _, portForwardDenied := conn.Permissions.Extensions["no-port-forwarding"]; portForwardDenied { + newChan.Reject(ssh.Prohibited, "port forwarding is disabled. For you in particular :-P") + return + } handleTCPIP(newChan) default: newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t)) @@ -112,7 +118,7 @@ func handleTCPIP(newChan ssh.NewChannel) { }() } -func handleSession(newChannel ssh.NewChannel) { +func handleSession(newChannel ssh.NewChannel, conn *ssh.ServerConn) { channel, requests, err := newChannel.Accept() if err != nil { log.Printf("Could not accept channel (%s)", err) @@ -120,12 +126,12 @@ func handleSession(newChannel ssh.NewChannel) { } // Sessions have out-of-band requests such as "shell", "pty-req" and "env" - go func(channel ssh.Channel, requests <-chan *ssh.Request) { + go func(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) { ctx, canc := context.WithCancel(context.Background()) defer canc() s := session{channel: channel} for req := range requests { - if err := s.request(ctx, req); err != nil { + if err := s.request(ctx, req, conn); err != nil { log.Printf("request(%q): %v", req.Type, err) errmsg := []byte(err.Error()) // Append a trailing newline; the error message is @@ -139,7 +145,7 @@ func handleSession(newChannel ssh.NewChannel) { } } log.Printf("requests exhausted") - }(channel, requests) + }(channel, requests, conn) } func expandPath(env []string) []string { @@ -220,24 +226,31 @@ func findShell() string { // in standard locations (makes Emacs TRAMP work, for example). if err := installBusybox(); err != nil { log.Printf("installing busybox failed: %v", err) - // fallthrough - } else { - return "/bin/sh" // available after installation + // fallthrough, we don't return /bin/sh as we read /etc/passwd } } + if _, err := exec.LookPath(shell); path.IsAbs(shell) && err == nil { + return shell + } + if path, err := exec.LookPath("bash"); err == nil { + return path + } if path, err := exec.LookPath("sh"); err == nil { return path } const wellKnownSerialShell = "/tmp/serial-busybox/ash" - if _, err := os.Stat(wellKnownSerialShell); err == nil { + if _, err := exec.LookPath(wellKnownSerialShell); err == nil { return wellKnownSerialShell } return "" } -func (s *session) request(ctx context.Context, req *ssh.Request) error { +func (s *session) request(ctx context.Context, req *ssh.Request, conn *ssh.ServerConn) error { switch req.Type { case "pty-req": + if _, portForwardDenied := conn.Permissions.Extensions["no-pty"]; portForwardDenied { + return errors.New("Pseudo-Terminal is disabled. For you in particular :-P") + } var r ptyreq if err := ssh.Unmarshal(req.Payload, &r); err != nil { return err @@ -355,21 +368,25 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error { // Ensure the $HOME directory exists so that shell history works without // any extra steps. - if err := os.MkdirAll("/perm/home", 0755); err != nil { + if err := os.MkdirAll(home, 0755); err != nil { // TODO: Suppress -EROFS log.Print(err) } var cmd *exec.Cmd if shell := findShell(); shell != "" { - cmd = exec.CommandContext(ctx, shell, "-c", r.Command) + if r.Command == "sh" { + cmd = exec.CommandContext(ctx, shell, "-l") + } else { + cmd = exec.CommandContext(ctx, shell, "-c", r.Command) + } } else { cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...) } log.Printf("Starting cmd %q", cmd.Args) env := expandPath(s.env) env = append(env, - "HOME=/perm/home", + "HOME="+home, "TMPDIR=/tmp") cmd.Env = env cmd.SysProcAttr = &syscall.SysProcAttr{}