Implement certificate authentication, certificate requires :gokrazy: principal Read first line of /etc/passwd for home and shell Shell uses `-l` to make it a login shell which will run .profile
This commit is contained in:
parent
86e60e7477
commit
bba58e7a3a
@ -10,6 +10,7 @@ import (
|
|||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
@ -17,6 +18,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
@ -30,6 +32,10 @@ var (
|
|||||||
"/perm/breakglass.authorized_keys",
|
"/perm/breakglass.authorized_keys",
|
||||||
"path to an OpenSSH authorized_keys file; if the value is 'ec2', fetch the SSH key(s) from the AWS IMDSv2 metadata")
|
"path to an OpenSSH authorized_keys file; if the value is 'ec2', fetch the SSH key(s) from the AWS IMDSv2 metadata")
|
||||||
|
|
||||||
|
authorizedUserCAPath = flag.String("authorized_ca",
|
||||||
|
"/perm/breakglass.authorized_user_ca",
|
||||||
|
"path to an OpenSSH TrustedUserCAKeys file; note the certificate must list ':gokrazy:' as a valid principal")
|
||||||
|
|
||||||
hostKeyPath = flag.String("host_key",
|
hostKeyPath = flag.String("host_key",
|
||||||
"/perm/breakglass.host_key",
|
"/perm/breakglass.host_key",
|
||||||
"path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)")
|
"path to a PEM-encoded RSA, DSA or ECDSA private key (create using e.g. ssh-keygen -f /perm/breakglass.host_key -N '' -t rsa)")
|
||||||
@ -45,6 +51,9 @@ var (
|
|||||||
forwarding = flag.String("forward",
|
forwarding = flag.String("forward",
|
||||||
"",
|
"",
|
||||||
"allow port forwarding. Use `loopback` for loopback interfaces and `private-network` for private networks")
|
"allow port forwarding. Use `loopback` for loopback interfaces and `private-network` for private networks")
|
||||||
|
|
||||||
|
home = "/perm/home"
|
||||||
|
shell = ""
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadAuthorizedKeys(path string) (map[string]bool, error) {
|
func loadAuthorizedKeys(path string) (map[string]bool, error) {
|
||||||
@ -80,6 +89,19 @@ func loadAuthorizedKeys(path string) (map[string]bool, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func loadPasswd(passwd string) {
|
||||||
|
b, err := os.ReadFile(passwd)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fields := bytes.SplitN(bytes.SplitN(b, []byte("\n"), 2)[0], []byte(":"), 7)
|
||||||
|
if len(fields) != 7 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
home = path.Clean(string(fields[5]))
|
||||||
|
shell = path.Clean(string(fields[6]))
|
||||||
|
}
|
||||||
|
|
||||||
func loadHostKey(path string) (ssh.Signer, error) {
|
func loadHostKey(path string) (ssh.Signer, error) {
|
||||||
b, err := ioutil.ReadFile(path)
|
b, err := ioutil.ReadFile(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -197,6 +219,8 @@ func main() {
|
|||||||
|
|
||||||
gokrazy.DontStartOnBoot()
|
gokrazy.DontStartOnBoot()
|
||||||
|
|
||||||
|
loadPasswd("/etc/passwd")
|
||||||
|
|
||||||
authorizedKeys, err := loadAuthorizedKeys(*authorizedKeysPath)
|
authorizedKeys, err := loadAuthorizedKeys(*authorizedKeysPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
@ -205,19 +229,59 @@ func main() {
|
|||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
authorizedUserCertificateCA, err := loadAuthorizedKeys(strings.TrimPrefix(*authorizedUserCAPath, "ec2"))
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
log.Printf("TrustedUserCAKeys not loaded")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := initMOTD(); err != nil {
|
if err := initMOTD(); err != nil {
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
|
certChecker := ssh.CertChecker{
|
||||||
config := &ssh.ServerConfig{
|
IsUserAuthority: func(auth ssh.PublicKey) bool {
|
||||||
PublicKeyCallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
return authorizedUserCertificateCA[string(auth.Marshal())]
|
||||||
|
},
|
||||||
|
UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
if authorizedKeys[string(pubKey.Marshal())] {
|
if authorizedKeys[string(pubKey.Marshal())] {
|
||||||
log.Printf("user %q successfully authorized from remote addr %s", conn.User(), conn.RemoteAddr())
|
log.Printf("user %q successfully authorized from remote addr %s", conn.User(), conn.RemoteAddr())
|
||||||
return nil, nil
|
return &ssh.Permissions{map[string]string{}, map[string]string{}}, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath)
|
return nil, fmt.Errorf("public key not found in %s", *authorizedKeysPath)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
config := &ssh.ServerConfig{
|
||||||
|
PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
|
cert, ok := key.(*ssh.Certificate)
|
||||||
|
if !ok {
|
||||||
|
if certChecker.UserKeyFallback != nil {
|
||||||
|
return certChecker.UserKeyFallback(conn, key)
|
||||||
|
}
|
||||||
|
return nil, errors.New("ssh: normal key pairs not accepted")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cert.CertType != ssh.UserCert {
|
||||||
|
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
|
||||||
|
}
|
||||||
|
if !certChecker.IsUserAuthority(cert.SignatureKey) {
|
||||||
|
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := certChecker.CheckCert(":gokrazy:", cert); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if cert.Permissions.CriticalOptions == nil {
|
||||||
|
cert.Permissions.CriticalOptions = map[string]string{}
|
||||||
|
}
|
||||||
|
if cert.Permissions.Extensions == nil {
|
||||||
|
cert.Permissions.Extensions = map[string]string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &cert.Permissions, nil
|
||||||
|
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
signer, err := loadHostKey(*hostKeyPath)
|
signer, err := loadHostKey(*hostKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -267,7 +331,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func(conn net.Conn) {
|
go func(conn net.Conn) {
|
||||||
_, chans, reqs, err := ssh.NewServerConn(conn, config)
|
c, chans, reqs, err := ssh.NewServerConn(conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("handshake: %v", err)
|
log.Printf("handshake: %v", err)
|
||||||
return
|
return
|
||||||
@ -277,7 +341,7 @@ func main() {
|
|||||||
go ssh.DiscardRequests(reqs)
|
go ssh.DiscardRequests(reqs)
|
||||||
|
|
||||||
for newChannel := range chans {
|
for newChannel := range chans {
|
||||||
handleChannel(newChannel)
|
handleChannel(newChannel, c)
|
||||||
}
|
}
|
||||||
}(conn)
|
}(conn)
|
||||||
}
|
}
|
||||||
|
43
ssh.go
43
ssh.go
@ -2,12 +2,14 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@ -21,11 +23,15 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func handleChannel(newChan ssh.NewChannel) {
|
func handleChannel(newChan ssh.NewChannel, conn *ssh.ServerConn) {
|
||||||
switch t := newChan.ChannelType(); t {
|
switch t := newChan.ChannelType(); t {
|
||||||
case "session":
|
case "session":
|
||||||
handleSession(newChan)
|
handleSession(newChan, conn)
|
||||||
case "direct-tcpip":
|
case "direct-tcpip":
|
||||||
|
if _, portForwardDenied := conn.Permissions.Extensions["no-port-forwarding"]; portForwardDenied {
|
||||||
|
newChan.Reject(ssh.Prohibited, "port forwarding is disabled. For you in particular :-P")
|
||||||
|
return
|
||||||
|
}
|
||||||
handleTCPIP(newChan)
|
handleTCPIP(newChan)
|
||||||
default:
|
default:
|
||||||
newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t))
|
newChan.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %q", t))
|
||||||
@ -112,7 +118,7 @@ func handleTCPIP(newChan ssh.NewChannel) {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleSession(newChannel ssh.NewChannel) {
|
func handleSession(newChannel ssh.NewChannel, conn *ssh.ServerConn) {
|
||||||
channel, requests, err := newChannel.Accept()
|
channel, requests, err := newChannel.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Could not accept channel (%s)", err)
|
log.Printf("Could not accept channel (%s)", err)
|
||||||
@ -120,12 +126,12 @@ func handleSession(newChannel ssh.NewChannel) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
|
// Sessions have out-of-band requests such as "shell", "pty-req" and "env"
|
||||||
go func(channel ssh.Channel, requests <-chan *ssh.Request) {
|
go func(channel ssh.Channel, requests <-chan *ssh.Request, conn *ssh.ServerConn) {
|
||||||
ctx, canc := context.WithCancel(context.Background())
|
ctx, canc := context.WithCancel(context.Background())
|
||||||
defer canc()
|
defer canc()
|
||||||
s := session{channel: channel}
|
s := session{channel: channel}
|
||||||
for req := range requests {
|
for req := range requests {
|
||||||
if err := s.request(ctx, req); err != nil {
|
if err := s.request(ctx, req, conn); err != nil {
|
||||||
log.Printf("request(%q): %v", req.Type, err)
|
log.Printf("request(%q): %v", req.Type, err)
|
||||||
errmsg := []byte(err.Error())
|
errmsg := []byte(err.Error())
|
||||||
// Append a trailing newline; the error message is
|
// Append a trailing newline; the error message is
|
||||||
@ -139,7 +145,7 @@ func handleSession(newChannel ssh.NewChannel) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.Printf("requests exhausted")
|
log.Printf("requests exhausted")
|
||||||
}(channel, requests)
|
}(channel, requests, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
func expandPath(env []string) []string {
|
func expandPath(env []string) []string {
|
||||||
@ -220,24 +226,31 @@ func findShell() string {
|
|||||||
// in standard locations (makes Emacs TRAMP work, for example).
|
// in standard locations (makes Emacs TRAMP work, for example).
|
||||||
if err := installBusybox(); err != nil {
|
if err := installBusybox(); err != nil {
|
||||||
log.Printf("installing busybox failed: %v", err)
|
log.Printf("installing busybox failed: %v", err)
|
||||||
// fallthrough
|
// fallthrough, we don't return /bin/sh as we read /etc/passwd
|
||||||
} else {
|
|
||||||
return "/bin/sh" // available after installation
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if _, err := exec.LookPath(shell); path.IsAbs(shell) && err == nil {
|
||||||
|
return shell
|
||||||
|
}
|
||||||
|
if path, err := exec.LookPath("bash"); err == nil {
|
||||||
|
return path
|
||||||
|
}
|
||||||
if path, err := exec.LookPath("sh"); err == nil {
|
if path, err := exec.LookPath("sh"); err == nil {
|
||||||
return path
|
return path
|
||||||
}
|
}
|
||||||
const wellKnownSerialShell = "/tmp/serial-busybox/ash"
|
const wellKnownSerialShell = "/tmp/serial-busybox/ash"
|
||||||
if _, err := os.Stat(wellKnownSerialShell); err == nil {
|
if _, err := exec.LookPath(wellKnownSerialShell); err == nil {
|
||||||
return wellKnownSerialShell
|
return wellKnownSerialShell
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *session) request(ctx context.Context, req *ssh.Request) error {
|
func (s *session) request(ctx context.Context, req *ssh.Request, conn *ssh.ServerConn) error {
|
||||||
switch req.Type {
|
switch req.Type {
|
||||||
case "pty-req":
|
case "pty-req":
|
||||||
|
if _, portForwardDenied := conn.Permissions.Extensions["no-pty"]; portForwardDenied {
|
||||||
|
return errors.New("Pseudo-Terminal is disabled. For you in particular :-P")
|
||||||
|
}
|
||||||
var r ptyreq
|
var r ptyreq
|
||||||
if err := ssh.Unmarshal(req.Payload, &r); err != nil {
|
if err := ssh.Unmarshal(req.Payload, &r); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -355,21 +368,25 @@ func (s *session) request(ctx context.Context, req *ssh.Request) error {
|
|||||||
|
|
||||||
// Ensure the $HOME directory exists so that shell history works without
|
// Ensure the $HOME directory exists so that shell history works without
|
||||||
// any extra steps.
|
// any extra steps.
|
||||||
if err := os.MkdirAll("/perm/home", 0755); err != nil {
|
if err := os.MkdirAll(home, 0755); err != nil {
|
||||||
// TODO: Suppress -EROFS
|
// TODO: Suppress -EROFS
|
||||||
log.Print(err)
|
log.Print(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var cmd *exec.Cmd
|
var cmd *exec.Cmd
|
||||||
if shell := findShell(); shell != "" {
|
if shell := findShell(); shell != "" {
|
||||||
|
if r.Command == "sh" {
|
||||||
|
cmd = exec.CommandContext(ctx, shell, "-l")
|
||||||
|
} else {
|
||||||
cmd = exec.CommandContext(ctx, shell, "-c", r.Command)
|
cmd = exec.CommandContext(ctx, shell, "-c", r.Command)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
|
cmd = exec.CommandContext(ctx, cmdline[0], cmdline[1:]...)
|
||||||
}
|
}
|
||||||
log.Printf("Starting cmd %q", cmd.Args)
|
log.Printf("Starting cmd %q", cmd.Args)
|
||||||
env := expandPath(s.env)
|
env := expandPath(s.env)
|
||||||
env = append(env,
|
env = append(env,
|
||||||
"HOME=/perm/home",
|
"HOME="+home,
|
||||||
"TMPDIR=/tmp")
|
"TMPDIR=/tmp")
|
||||||
cmd.Env = env
|
cmd.Env = env
|
||||||
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
cmd.SysProcAttr = &syscall.SysProcAttr{}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user