Stuff
Some checks failed
Push / CI (push) Has been cancelled

Implement certificate authentication, certificate requires :gokrazy: principal
Read first line of /etc/passwd for home and shell
Shell uses `-l` to make it a login shell which will run .profile
This commit is contained in:
Timmy Welch 2025-02-16 17:53:26 -08:00
parent 86e60e7477
commit bba58e7a3a
2 changed files with 102 additions and 21 deletions

View File

@ -10,6 +10,7 @@ import (
"crypto/x509" "crypto/x509"
"encoding/json" "encoding/json"
"encoding/pem" "encoding/pem"
"errors"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -17,6 +18,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path"
"strings" "strings"
"syscall" "syscall"
@ -30,6 +32,10 @@ var (
"/perm/breakglass.authorized_keys", "/perm/breakglass.authorized_keys",
"path to an OpenSSH authorized_keys file; if the value is 'ec2', fetch the SSH key(s) from the AWS IMDSv2 metadata") "path to an OpenSSH authorized_keys file; if the value is 'ec2', fetch the SSH key(s) from the AWS IMDSv2 metadata")
authorizedUserCAPath = flag.String("authorized_ca",
"/perm/breakglass.authorized_user_ca",
"path to an OpenSSH TrustedUserCAKeys file; note the certificate must list ':gokrazy:' as a valid principal")
hostKeyPath = flag.String("host_key", hostKeyPath = flag.String("host_key",
"/perm/breakglass.host_key", "/perm/breakglass.host_key",
"path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)") "path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)")
@ -45,6 +51,9 @@ var (
forwarding = flag.String("forward", forwarding = flag.String("forward",
"", "",
"allow port forwarding. Use `loopback` for loopback interfaces and `private-network` for private networks") "allow port forwarding. Use `loopback` for loopback interfaces and `private-network` for private networks")
home = "/perm/home"
shell = ""
) )
func loadAuthorizedKeys(path string) (map[string]bool, error) { func loadAuthorizedKeys(path string) (map[string]bool, error) {
@ -80,6 +89,19 @@ func loadAuthorizedKeys(path string) (map[string]bool, error) {
return result, nil return result, nil
} }
func loadPasswd(passwd string) {
b, err := os.ReadFile(passwd)
if err != nil {
return
}
fields := bytes.SplitN(bytes.SplitN(b, []byte("\n"), 2)[0], []byte(":"), 7)
if len(fields) != 7 {
return
}
home = path.Clean(string(fields[5]))
shell = path.Clean(string(fields[6]))
}
func loadHostKey(path string) (ssh.Signer, error) { func loadHostKey(path string) (ssh.Signer, error) {
b, err := ioutil.ReadFile(path) b, err := ioutil.ReadFile(path)
if err != nil { if err != nil {
@ -197,6 +219,8 @@ func main() {
gokrazy.DontStartOnBoot() gokrazy.DontStartOnBoot()
loadPasswd("/etc/passwd")
authorizedKeys, err := loadAuthorizedKeys(*authorizedKeysPath) authorizedKeys, err := loadAuthorizedKeys(*authorizedKeysPath)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -205,19 +229,59 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
authorizedUserCertificateCA, err := loadAuthorizedKeys(strings.TrimPrefix(*authorizedUserCAPath, "ec2"))
if err != nil {
if os.IsNotExist(err) {
log.Printf("TrustedUserCAKeys not loaded")
}
}
if err := initMOTD(); err != nil { if err := initMOTD(); err != nil {
log.Print(err) log.Print(err)
} }
certChecker := ssh.CertChecker{
config := &ssh.ServerConfig{ IsUserAuthority: func(auth ssh.PublicKey) bool {
PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { return authorizedUserCertificateCA[string(auth.Marshal())]
},
UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeys[string(pubKey.Marshal())] { if authorizedKeys[string(pubKey.Marshal())] {
log.Printf("user %q successfully authorized from remote addr %s", conn.User(), conn.RemoteAddr()) log.Printf("user %q successfully authorized from remote addr %s", conn.User(), conn.RemoteAddr())
return nil, nil return &ssh.Permissions{map[string]string{}, map[string]string{}}, nil
} }
return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath) return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath)
}, },
} }
config := &ssh.ServerConfig{
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
cert, ok := key.(*ssh.Certificate)
if !ok {
if certChecker.UserKeyFallback != nil {
return certChecker.UserKeyFallback(conn, key)
}
return nil, errors.New("ssh: normal key pairs not accepted")
}
if cert.CertType != ssh.UserCert {
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
}
if !certChecker.IsUserAuthority(cert.SignatureKey) {
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
}
if err := certChecker.CheckCert(":gokrazy:", cert); err != nil {
return nil, err
}
if cert.Permissions.CriticalOptions == nil {
cert.Permissions.CriticalOptions = map[string]string{}
}
if cert.Permissions.Extensions == nil {
cert.Permissions.Extensions = map[string]string{}
}
return &cert.Permissions, nil
},
}
signer, err := loadHostKey(*hostKeyPath) signer, err := loadHostKey(*hostKeyPath)
if err != nil { if err != nil {
@ -267,7 +331,7 @@ func main() {
} }
go func(conn net.Conn) { go func(conn net.Conn) {
_, chans, reqs, err := ssh.NewServerConn(conn, config) c, chans, reqs, err := ssh.NewServerConn(conn, config)
if err != nil { if err != nil {
log.Printf("handshake: %v", err) log.Printf("handshake: %v", err)
return return
@ -277,7 +341,7 @@ func main() {
go ssh.DiscardRequests(reqs) go ssh.DiscardRequests(reqs)
for newChannel := range chans { for newChannel := range chans {
handleChannel(newChannel) handleChannel(newChannel, c)
} }
}(conn) }(conn)
} }

45
ssh.go
View File

@ -2,12 +2,14 @@ package main
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"os" "os"
"os/exec" "os/exec"
"path"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -21,11 +23,15 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
func handleChannel(newChan ssh.NewChannel) { func handleChannel(newChan ssh.NewChannel, conn *ssh.ServerConn) {
switch t := newChan.ChannelType(); t { switch t := newChan.ChannelType(); t {
case "session": case "session":
handleSession(newChan) handleSession(newChan, conn)
case "direct-tcpip": case "direct-tcpip":
if _, portForwardDenied := conn.Permissions.Extensions["no-port-forwarding"]; portForwardDenied {
newChan.Reject(ssh.Prohibited, "port forwarding is disabled. For you in particular :-P")
return
}
handleTCPIP(newChan) handleTCPIP(newChan)
default: default:
newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t)) newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t))
@ -112,7 +118,7 @@ func handleTCPIP(newChan ssh.NewChannel) {
}() }()
} }
func handleSession(newChannel ssh.NewChannel) { func handleSession(newChannel ssh.NewChannel, conn *ssh.ServerConn) {
channel, requests, err := newChannel.Accept() channel, requests, err := newChannel.Accept()
if err != nil { if err != nil {
log.Printf("Could not accept channel (%s)", err) log.Printf("Could not accept channel (%s)", err)
@ -120,12 +126,12 @@ func handleSession(newChannel ssh.NewChannel) {
} }
// Sessions have out-of-band requests such as "shell", "pty-req" and "env" // Sessions have out-of-band requests such as "shell", "pty-req" and "env"
go func(channel ssh.Channel, requests <-chan *ssh.Request) { go func(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
ctx, canc := context.WithCancel(context.Background()) ctx, canc := context.WithCancel(context.Background())
defer canc() defer canc()
s := session{channel: channel} s := session{channel: channel}
for req := range requests { for req := range requests {
if err := s.request(ctx, req); err != nil { if err := s.request(ctx, req, conn); err != nil {
log.Printf("request(%q): %v", req.Type, err) log.Printf("request(%q): %v", req.Type, err)
errmsg := []byte(err.Error()) errmsg := []byte(err.Error())
// Append a trailing newline; the error message is // Append a trailing newline; the error message is
@ -139,7 +145,7 @@ func handleSession(newChannel ssh.NewChannel) {
} }
} }
log.Printf("requests exhausted") log.Printf("requests exhausted")
}(channel, requests) }(channel, requests, conn)
} }
func expandPath(env []string) []string { func expandPath(env []string) []string {
@ -220,24 +226,31 @@ func findShell() string {
// in standard locations (makes Emacs TRAMP work, for example). // in standard locations (makes Emacs TRAMP work, for example).
if err := installBusybox(); err != nil { if err := installBusybox(); err != nil {
log.Printf("installing busybox failed: %v", err) log.Printf("installing busybox failed: %v", err)
// fallthrough // fallthrough, we don't return /bin/sh as we read /etc/passwd
} else {
return "/bin/sh" // available after installation
} }
} }
if _, err := exec.LookPath(shell); path.IsAbs(shell) && err == nil {
return shell
}
if path, err := exec.LookPath("bash"); err == nil {
return path
}
if path, err := exec.LookPath("sh"); err == nil { if path, err := exec.LookPath("sh"); err == nil {
return path return path
} }
const wellKnownSerialShell = "/tmp/serial-busybox/ash" const wellKnownSerialShell = "/tmp/serial-busybox/ash"
if _, err := os.Stat(wellKnownSerialShell); err == nil { if _, err := exec.LookPath(wellKnownSerialShell); err == nil {
return wellKnownSerialShell return wellKnownSerialShell
} }
return "" return ""
} }
func (s *session) request(ctx context.Context, req *ssh.Request) error { func (s *session) request(ctx context.Context, req *ssh.Request, conn *ssh.ServerConn) error {
switch req.Type { switch req.Type {
case "pty-req": case "pty-req":
if _, portForwardDenied := conn.Permissions.Extensions["no-pty"]; portForwardDenied {
return errors.New("Pseudo-Terminal is disabled. For you in particular :-P")
}
var r ptyreq var r ptyreq
if err := ssh.Unmarshal(req.Payload, &r); err != nil { if err := ssh.Unmarshal(req.Payload, &r); err != nil {
return err return err
@ -355,21 +368,25 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error {
// Ensure the $HOME directory exists so that shell history works without // Ensure the $HOME directory exists so that shell history works without
// any extra steps. // any extra steps.
if err := os.MkdirAll("/perm/home", 0755); err != nil { if err := os.MkdirAll(home, 0755); err != nil {
// TODO: Suppress -EROFS // TODO: Suppress -EROFS
log.Print(err) log.Print(err)
} }
var cmd *exec.Cmd var cmd *exec.Cmd
if shell := findShell(); shell != "" { if shell := findShell(); shell != "" {
cmd = exec.CommandContext(ctx, shell, "-c", r.Command) if r.Command == "sh" {
cmd = exec.CommandContext(ctx, shell, "-l")
} else {
cmd = exec.CommandContext(ctx, shell, "-c", r.Command)
}
} else { } else {
cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...) cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
} }
log.Printf("Starting cmd %q", cmd.Args) log.Printf("Starting cmd %q", cmd.Args)
env := expandPath(s.env) env := expandPath(s.env)
env = append(env, env = append(env,
"HOME=/perm/home", "HOME="+home,
"TMPDIR=/tmp") "TMPDIR=/tmp")
cmd.Env = env cmd.Env = env
cmd.SysProcAttr = &syscall.SysProcAttr{} cmd.SysProcAttr = &syscall.SysProcAttr{}