5 Commits

Author SHA1 Message Date
Timmy Welch
c570b14537 Setup logging after opening socket so files aren't rotated early 2026-01-14 15:56:48 -08:00
Timmy Welch
58016c6889 Fix output on startup and TERM signal 2026-01-10 14:06:19 -08:00
Timmy Welch
8989dc25ac Retrieve CA Key for display on startup 2026-01-10 13:12:50 -08:00
Timmy Welch
899aad07b2 modernize 2026-01-10 12:56:08 -08:00
Timmy Welch
2bc73596d3 Updates
Better forking
Simplify socket checking
Use logger properly
Fix errors/panic's killing the agent
Ensure a user-agent is used for http requests
2026-01-10 12:55:27 -08:00
11 changed files with 355 additions and 176 deletions

View File

@@ -2,13 +2,12 @@ package main
import (
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"net"
"os"
"os/signal"
@@ -17,13 +16,15 @@ import (
"strings"
"syscall"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"git.narnian.us/lordwelch/sshrimp/internal/signer"
"git.narnian.us/lordwelch/sshrimp/internal/sshrimpagent"
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
"gitea.narnian.us/lordwelch/sshrimp/internal/signer"
"gitea.narnian.us/lordwelch/sshrimp/internal/sshrimpagent"
"github.com/prometheus/procfs"
"github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/writer"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"inet.af/peercred"
)
var (
@@ -38,6 +39,7 @@ type cfg struct {
Config string
LogDirectory string
Verbose bool
Foreground bool
}
func getLogDir() string {
@@ -72,8 +74,8 @@ func setupLoging(config cfg) error {
logrus.ErrorLevel,
logrus.WarnLevel,
}
err := os.MkdirAll(config.LogDirectory, 0750)
if err != nil && !os.IsExist(err) {
err := os.MkdirAll(config.LogDirectory, 0o750)
if err != nil {
log.Fatal(err)
}
@@ -91,7 +93,7 @@ func setupLoging(config cfg) error {
Writer: os.Stderr,
LogLevels: levels,
})
logger.Out = ioutil.Discard
logger.Out = io.Discard
file, err := os.Create(logName)
if err != nil {
return err
@@ -122,38 +124,127 @@ func ExpandPath(path string) string {
return path
}
func main2(cli cfg, c *config.SSHrimp) {
listener := openSocket(ExpandPath(c.Agent.Socket))
if listener == nil {
log.Errorln("Failed to open socket")
return
}
err := setupLoging(cli)
if err != nil {
log.Warnf("Error setting up logging: %v", err)
}
err = launchAgent(c, listener)
if err != nil {
log.Panic("Failed to launch agent", err)
}
}
func main() {
defaultConfigPath := "~/.ssh/sshrimp.toml"
if configPathFromEnv, ok := os.LookupEnv("SSHRIMP_CONFIG"); ok && configPathFromEnv != "" {
defaultConfigPath = configPathFromEnv
}
var cli cfg
var (
cli cfg
err error
)
flag.StringVar(&cli.Config, "config", defaultConfigPath, "sshrimp config file")
flag.StringVar(&cli.LogDirectory, "log", getLogDir(), "sshrimp log directory")
flag.BoolVar(&cli.Verbose, "v", false, "enable verbose logging")
fmt.Println(getLogDir())
flag.BoolVar(&cli.Foreground, "f", false, "Run in the foreground")
flag.Parse()
sshCommand := flag.Args()
if cli.Verbose {
logger.SetLevel(logrus.DebugLevel)
}
cfgFile := ExpandPath(cli.Config)
cfgFile, err = filepath.Abs(cfgFile)
if err != nil {
log.Errorln("config must be an absolute path")
os.Exit(1)
}
c := config.NewSSHrimpWithDefaults()
err := c.Read(ExpandPath(cli.Config))
err = c.Read(cfgFile)
if err != nil {
panic(err)
}
listener := openSocket(ExpandPath(c.Agent.Socket))
if listener == nil {
logger.Errorln("Failed to open socket")
return
if os.Getenv("SSHRIMP_DAEMON") == "true" {
cli.Foreground = true
}
if err := setupLoging(cli); err != nil {
logger.Warnf("Error setting up logging: %v", err)
if cli.Foreground {
logger.Println("Launching agent")
main2(cli, c)
} else {
logger.Debug("Attempting to start daemon")
var nullFile *os.File
nullFile, err = os.Open(os.DevNull)
if err != nil {
panic(err)
}
env := os.Environ()
env = append(env, "SSHRIMP_DAEMON=true")
executable, err := os.Executable()
if err != nil {
panic(err)
}
_, err = os.StartProcess(executable, os.Args, &os.ProcAttr{
Dir: filepath.Dir(cfgFile),
Env: env,
Files: []*os.File{nullFile, nullFile, nullFile},
Sys: &syscall.SysProcAttr{
// Chroot: d.Chroot,
Setsid: true,
},
})
if err != nil {
panic(err)
}
nullFile.Close()
logger.Debugf("Agent started in the background check %s for logs", getLogDir())
}
err = launchAgent(c, listener)
if err != nil {
panic(err)
if len(sshCommand) > 1 && filepath.Base(sshCommand[0]) == "ssh" {
syscall.Exec(sshCommand[0], sshCommand, os.Environ())
}
}
func socketWorks(path string) bool {
var (
pid int
cred *peercred.Creds
)
conn, sockErr := net.Dial("unix", path)
if sockErr != nil {
return false
}
if conn == nil {
return false
}
defer conn.Close()
cred, sockErr = peercred.Get(conn)
if sockErr != nil {
return false
}
var (
ok bool
process *os.Process
)
pid, ok = cred.PID()
if !ok {
return false
}
process, sockErr = os.FindProcess(pid)
if sockErr != nil {
return false
}
defer process.Release()
return process.Signal(syscall.SIGHUP) == nil
}
func openSocket(socketPath string) net.Listener {
var (
listener net.Listener
@@ -161,26 +252,17 @@ func openSocket(socketPath string) net.Listener {
logMessage string
)
if _, err = os.Stat(socketPath); err == nil {
fmt.Println("Creating socket")
fmt.Printf("File already exists at %s\n", socketPath)
conn, sockErr := net.Dial("unix", socketPath)
if conn == nil {
logMessage = "conn is nil"
}
if sockErr == nil { // socket is accepting connections
conn.Close()
fmt.Printf("socket %s already exists\n", socketPath)
return nil
}
fmt.Printf("Socket is not connected %s\n", logMessage)
err = os.Remove(socketPath)
if err == nil { // socket is not accepting connections, assuming safe to remove
fmt.Println("Deleting socket: success")
} else {
fmt.Println("Deleting socket: fail", err)
return nil
}
if socketWorks(socketPath) { // socket is accepting connections
log.Printf("socket %s already exists\n", socketPath)
return nil
}
log.Printf("Socket is not connected %s\n", logMessage)
err = os.Remove(socketPath)
if err == nil { // socket is not accepting connections, assuming safe to remove
log.Println("Deleting socket: success")
} else if !errors.Is(err, os.ErrNotExist) {
log.Println("Deleting socket: fail", err)
return nil
}
// This affects all files created for the process. Since this is a sensitive
@@ -188,12 +270,62 @@ func openSocket(socketPath string) net.Listener {
syscall.Umask(0o077)
listener, err = net.Listen("unix", socketPath)
if err != nil {
fmt.Println("Error opening socket:", err)
log.Println("Error opening socket:", err)
return nil
}
log.Println("Opened socket", socketPath)
return listener
}
func getConnectedProcess(conn net.Conn) string {
var (
cred *peercred.Creds
err error
)
cred, err = peercred.Get(conn)
if err != nil {
return ""
}
pid, ok := cred.PID()
if !ok {
return ""
}
var (
proc procfs.Proc
name string
)
proc, err = procfs.NewProc(pid)
if err != nil {
return fmt.Sprintf("pid %d", pid)
}
name, err = proc.Executable()
if err == nil {
return fmt.Sprintf("pid %d", pid)
}
return name
}
func handle(sshrimpAgent agent.Agent, conn net.Conn) (err error) {
defer func() {
panicErr := recover()
if panicErr != nil {
if err != nil {
err = fmt.Errorf("something panicked: %w: %v", err, panicErr)
return
}
err, _ = panicErr.(error)
return
}
}()
log.Infof("Serving agent to %s", getConnectedProcess(conn))
if err = agent.ServeAgent(sshrimpAgent, conn); err != nil && !errors.Is(err, io.EOF) {
log.Errorf("Error serving agent: %v", err)
return err
}
return err
}
func launchAgent(c *config.SSHrimp, listener net.Listener) error {
var (
err error
@@ -202,11 +334,11 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
)
defer listener.Close()
fmt.Printf("listening on %s\n", c.Agent.Socket)
log.Printf("listening on %s\n", c.Agent.Socket)
// Generate a new SSH private/public key pair
log.Tracef("Generating RSA %d ssh keys", 2048)
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
log.Tracef("Generating ed25519 ssh keys")
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return err
}
@@ -220,7 +352,7 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
log.Traceln("Creating new sshrimp agent from sshSigner and config")
sshrimpAgent, err := sshrimpagent.NewSSHrimpAgent(c, sshSigner)
if err != nil {
log.Logger.Errorf("Failed to create sshrimpAgent: %v", err)
log.Errorf("Failed to create sshrimpAgent: %v", err)
}
// Listen for signals so that we can close the listener and exit nicely
@@ -230,8 +362,9 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
osSignals := make(chan os.Signal, 10)
signal.Notify(osSignals, sigExit...)
go func() {
<-osSignals
listener.Close()
sig := <-osSignals
log.Infof("Recieved signal %v: closing", sig)
os.Exit(0)
}()
log.Traceln("Starting main loop")
@@ -243,14 +376,10 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
log.Errorf("Error accepting connection: %v", err)
if strings.Contains(err.Error(), "use of closed network connection") {
// Occurs if the user interrupts the agent with a ctrl-c signal
return nil
continue
}
return err
}
log.Traceln("Serving agent")
if err = agent.ServeAgent(sshrimpAgent, conn); err != nil && !errors.Is(err, io.EOF) {
log.Errorf("Error serving agent: %v", err)
return err
log.Errorf("Error accepting connection: %v", err)
}
go handle(sshrimpAgent, conn)
}
}

View File

@@ -1,5 +1,4 @@
//go:build darwin || linux
// +build darwin linux
package main

View File

@@ -14,14 +14,17 @@ import (
"log"
"net/http"
"os"
"strings"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"git.narnian.us/lordwelch/sshrimp/internal/signer"
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
"gitea.narnian.us/lordwelch/sshrimp/internal/signer"
"github.com/BurntSushi/toml"
gonanoid "github.com/matoous/go-nanoid/v2"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
func httpError(w http.ResponseWriter, v interface{}, statusCode int) {
func httpError(w http.ResponseWriter, v any, statusCode int) {
var b bytes.Buffer
e := json.NewEncoder(&b)
_ = e.Encode(v)
@@ -31,16 +34,22 @@ func httpError(w http.ResponseWriter, v interface{}, statusCode int) {
type Server struct {
config *config.SSHrimp
Key ssh.Signer
Log *logrus.Logger
}
func NewServer(cfg *config.SSHrimp) (*Server, error) {
server := &Server{
config: cfg,
Log: logrus.New(),
}
return server, nil
server.Log.SetLevel(logrus.DebugLevel)
return server, server.LoadKey()
}
func (s *Server) LoadKey() error {
if s.config.CertificateAuthority.KeyPath == "" {
return fmt.Errorf("key path missing")
}
b, err := os.ReadFile(s.config.CertificateAuthority.KeyPath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
@@ -79,14 +88,18 @@ func (s *Server) GenerateKey() error {
// ServeHTTP handles a request to sign an SSH public key verified by an OpenIDConnect id_token
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
txid := gonanoid.Must()
log := s.Log.WithField("X-Request-ID", txid)
w.Header().Add("X-Request-ID", txid)
defer r.Body.Close()
if r.URL.Path == "/config" {
if strings.HasPrefix(r.URL.Path, "/config") {
io.Copy(io.Discard, r.Body)
new_config := *s.config
// new_config.CertificateAuthority
newConfig := *s.config
newConfig.CertificateAuthority = config.CertificateAuthority{}
w.Header().Add("Content-Type", "application/toml")
w.Header().Add("Content-Disposition", `attachment; filename="sshrimp.toml"`)
t := toml.NewEncoder(w)
_ = t.Encode(new_config)
_ = t.Encode(newConfig)
return
}
if r.Header.Get("Content-Type") != "application/json" {
@@ -103,7 +116,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
certificate, err := signer.ValidateRequest(event, s.config, r.Header.Get("Function-Execution-Id"), fmt.Sprintf("%s/%s/%s", os.Getenv("GCP_PROJECT"), os.Getenv("FUNCTION_REGION"), os.Getenv("FUNCTION_NAME")))
certificate, err := signer.ValidateRequest(log, event, s.config, txid, s.Key.PublicKey())
if err != nil {
httpError(w, signer.SSHrimpResult{Certificate: "", ErrorMessage: err.Error(), ErrorType: http.StatusText(http.StatusBadRequest)}, http.StatusBadRequest)
return
@@ -151,6 +164,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func main() {
cfgFile := flag.String("config", "/etc/sshrimp.toml", "Path to sshrimp.toml")
addr := flag.String("addr", "127.0.0.1:8080", "Address to listen on")
flag.Parse()
cfg := config.NewSSHrimp()
if err := cfg.Read(*cfgFile); err != nil {
log.Printf("Unable to read config file %s: %v", *cfgFile, err)

9
go.mod
View File

@@ -1,4 +1,4 @@
module git.narnian.us/lordwelch/sshrimp
module gitea.narnian.us/lordwelch/sshrimp
go 1.24.0
@@ -8,19 +8,24 @@ require (
github.com/BurntSushi/toml v1.6.0
github.com/coreos/go-oidc/v3 v3.17.0
github.com/google/uuid v1.6.0
github.com/matoous/go-nanoid/v2 v2.1.0
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/prometheus/procfs v0.19.2
github.com/sirupsen/logrus v1.9.3
github.com/zitadel/oidc v1.13.5
golang.org/x/crypto v0.46.0
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
inet.af/peercred v0.0.0-20210906144145-0893ea02156a
)
require (
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/gorilla/schema v1.4.1 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/jeremija/gosubmit v0.2.8 // indirect
github.com/rs/cors v1.11.1 // indirect
github.com/stretchr/testify v1.11.1 // indirect
github.com/zitadel/logging v0.6.2 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
golang.org/x/sys v0.40.0 // indirect

23
go.sum
View File

@@ -19,38 +19,47 @@ github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc=
github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI=
github.com/jeremija/gosubmit v0.2.8 h1:mmSITBz9JxVtu8eqbN+zmmwX7Ij2RidQxhcwRVI4wqA=
github.com/jeremija/gosubmit v0.2.8/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI=
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo=
github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/zitadel/logging v0.3.4 h1:9hZsTjMMTE3X2LUi0xcF9Q9EdLo+FAezeu52ireBbHM=
github.com/zitadel/logging v0.3.4/go.mod h1:aPpLQhE+v6ocNK0TWrBrd363hZ95KcI17Q1ixAQwZF0=
github.com/zitadel/logging v0.6.2 h1:MW2kDDR0ieQynPZ0KIZPrh9ote2WkxfBif5QoARDQcU=
github.com/zitadel/logging v0.6.2/go.mod h1:z6VWLWUkJpnNVDSLzrPSQSQyttysKZ6bCRongw0ROK4=
github.com/zitadel/oidc v1.13.5 h1:7jhh68NGZitLqwLiVU9Dtwa4IraJPFF1vS+4UupO93U=
github.com/zitadel/oidc v1.13.5/go.mod h1:rHs1DhU3Sv3tnI6bQRVlFa3u0lCwtR7S21WHY+yXgPA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20210301091718-77cc2087c03b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -59,3 +68,5 @@ gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
inet.af/peercred v0.0.0-20210906144145-0893ea02156a h1:qdkS8Q5/i10xU2ArJMKYhVa1DORzBfYS/qA2UK2jheg=
inet.af/peercred v0.0.0-20210906144145-0893ea02156a/go.mod h1:FjawnflS/udxX+SvpsMgZfdqx2aykOlkISeAsADi5IU=

View File

@@ -15,14 +15,14 @@ import (
// Agent config for the sshrimp-agent agent
type Agent struct {
ProviderURL string
ClientID string
ClientSecret string
Socket string
Scopes []string
KeyPath string
Port int
CAUrls []string
ProviderURL string
ClientID string
ClientSecret string
Socket string
Scopes []string
KeyPath string
Port int
CAUrls []string
}
// CertificateAuthority config for the sshrimp-ca lambda
@@ -88,7 +88,7 @@ func NewSSHrimpWithDefaults() *SSHrimp {
return &sshrimp
}
func validateInt(val interface{}) error {
func validateInt(val any) error {
if str, ok := val.(string); ok {
if _, err := strconv.Atoi(str); err != nil {
return err
@@ -100,7 +100,7 @@ func validateInt(val interface{}) error {
return nil
}
func validateURL(val interface{}) error {
func validateURL(val any) error {
if str, ok := val.(string); ok {
if _, err := url.ParseRequestURI(str); err != nil {
return err
@@ -112,7 +112,7 @@ func validateURL(val interface{}) error {
return nil
}
func validateDuration(val interface{}) error {
func validateDuration(val any) error {
if str, ok := val.(string); ok {
if _, err := time.ParseDuration(str); err != nil {
return err
@@ -124,7 +124,7 @@ func validateDuration(val interface{}) error {
return nil
}
func validateAlias(val interface{}) error {
func validateAlias(val any) error {
if str, ok := val.(string); ok {
if !strings.HasPrefix(str, "alias/") {
return errors.New("KMS alias must begin with alias/")

39
internal/http/client.go Normal file
View File

@@ -0,0 +1,39 @@
package http
import (
"net"
"net/http"
"time"
)
type Transport struct {
http.RoundTripper
UserAgent string
}
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Header.Get("User-Agent") == "" {
req.Header.Set("User-Agent", t.UserAgent)
}
if t.RoundTripper == nil {
d := &net.Dialer{
Timeout: 2 * time.Second,
KeepAlive: 30 * time.Second,
}
t.RoundTripper = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: d.DialContext,
ForceAttemptHTTP2: false,
MaxIdleConns: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 2 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
}
return t.RoundTripper.RoundTrip(req)
}
var Client = &http.Client{
Transport: &Transport{UserAgent: "sshrimp-agent"},
Timeout: 10 * time.Second,
}

View File

@@ -3,57 +3,27 @@ package identity
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"regexp"
"strings"
"unicode/utf8"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/sirupsen/logrus"
)
func init() {
// Disable log prefixes such as the default timestamp.
// Prefix text prevents the message from being parsed as JSON.
// A timestamp is added when shipping logs to Cloud Logging.
log.SetFlags(0)
}
// Entry defines a log entry.
type Entry struct {
Message string `json:"message"`
Severity string `json:"severity,omitempty"`
Trace string `json:"logging.googleapis.com/trace,omitempty"`
// Logs Explorer allows filtering and display of this as `jsonPayload.component`.
Component string `json:"component,omitempty"`
}
// String renders an entry structure to the JSON format expected by Cloud Logging.
func (e Entry) String() string {
if e.Severity == "" {
e.Severity = "INFO"
}
out, err := json.Marshal(e)
if err != nil {
log.Printf("json.Marshal: %v", err)
}
return string(out)
}
// Identity holds information required to verify an OIDC identity token
type Identity struct {
ctx context.Context
verifier *oidc.IDTokenVerifier
usernameREs []*regexp.Regexp
usernameClaims []string
log *logrus.Entry
}
// NewIdentity return a new Identity, with default values and oidc proivder information populated
func NewIdentity(c *config.SSHrimp) (*Identity, error) {
func NewIdentity(log *logrus.Entry, c *config.SSHrimp) (*Identity, error) {
ctx := context.Background()
provider, err := oidc.NewProvider(ctx, c.Agent.ProviderURL)
if err != nil {
@@ -75,6 +45,7 @@ func NewIdentity(c *config.SSHrimp) (*Identity, error) {
verifier: provider.Verifier(oidcConfig),
usernameREs: regexes,
usernameClaims: c.CertificateAuthority.UsernameClaims,
log: log,
}, nil
}
@@ -89,20 +60,17 @@ func (i *Identity) Validate(token string) ([]string, error) {
}
func (i *Identity) getUsernames(idToken *oidc.IDToken) ([]string, error) {
var claims map[string]interface{}
var claims map[string]any
if err := idToken.Claims(&claims); err != nil {
return nil, errors.New("failed to parse claims: " + err.Error())
}
usernames := make([]string, 0, len(i.usernameClaims))
for idx, claim := range i.usernameClaims {
claimedUsernames := getClaim(claim, claims)
claimedUsernames := i.getClaim(claim, claims)
if len(claimedUsernames) == 0 {
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Did not find a username using: getClaim(%#v, %#v)", claim, claims),
})
i.log.Errorf("Did not find a username using: getClaim(%#v, %#v)", claim, claims)
}
if idx < len(i.usernameREs) {
@@ -114,10 +82,7 @@ func (i *Identity) getUsernames(idToken *oidc.IDToken) ([]string, error) {
}
}
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Adding usernames: %v", usernames),
})
i.log.Infof("Adding usernames: %v", usernames)
if len(usernames) < 1 {
return nil, errors.New("configured username claim not in identity token")
}
@@ -131,13 +96,13 @@ func parseUsername(username string, re *regexp.Regexp) string {
return ""
}
func getClaim(claim string, claims map[string]interface{}) []string {
func (i *Identity) getClaim(claim string, claims map[string]any) []string {
usernames := make([]string, 0, 2)
parts := strings.Split(claim, ".")
f:
for idx, part := range parts {
switch v := claims[part].(type) {
case map[string]interface{}:
case map[string]any:
claims = v
case []map[string]string:
for _, claimItem := range v {
@@ -147,7 +112,7 @@ f:
}
}
break f
case []interface{}:
case []any:
for _, value := range v {
if name, ok := value.(string); ok {
usernames = append(usernames, name)
@@ -161,21 +126,15 @@ f:
}
}
return base64Decode(usernames)
return i.base64Decode(usernames)
}
func base64Decode(names []string) []string {
func (i *Identity) base64Decode(names []string) []string {
for idx, name := range names {
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Attempting to decode %q as base64\n", name),
})
i.log.Debugf("Attempting to decode %q as base64\n", name)
decoded, err := base64.RawURLEncoding.DecodeString(name)
if err == nil && utf8.Valid(decoded) {
names[idx] = string(decoded)
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Successfully decoded %q as base64\n", names[idx]),
})
i.log.Debugf("Successfully decoded %q as base64\n", names[idx])
}
}
return names

View File

@@ -5,19 +5,18 @@ import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"math"
"math/big"
"net/http"
"regexp"
"strings"
"time"
"errors"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"git.narnian.us/lordwelch/sshrimp/internal/identity"
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
"gitea.narnian.us/lordwelch/sshrimp/internal/http"
"gitea.narnian.us/lordwelch/sshrimp/internal/identity"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
@@ -43,11 +42,12 @@ type SSHrimpEvent struct {
// SignCertificateAllURLs iterate through each configured url if there is an error signing the certificate
func SignCertificateAllURLs(publicKey ssh.PublicKey, token string, forceCommand string, urls []string) (*ssh.Certificate, error) {
var (
err error
err = fmt.Errorf("no urls found to sign certificate")
cert *ssh.Certificate
)
// Try each configured url before exiting if there is an error
Log.Logger.Tracef("Attempting to sign cert with urls %v", urls)
for _, url := range urls {
cert, err = SignCertificate(publicKey, token, forceCommand, url)
if err == nil {
@@ -69,12 +69,12 @@ func SignCertificate(publicKey ssh.PublicKey, token string, forceCommand string,
return nil, err
}
var uri string
result, err := http.Post(uri, "application/json", bytes.NewReader(payload))
Log.Logger.Tracef("Posting to url %s", url)
result, err := http.Client.Post(url, "application/json", bytes.NewReader(payload))
if err != nil {
return nil, fmt.Errorf("http post failed: %w", err)
}
Log.Logger.Tracef("Reading body length: %d", result.ContentLength)
resbody, err := io.ReadAll(result.Body)
if err != nil {
return nil, fmt.Errorf("failed to retrieve the response from sshrimp-ca: %w", err)
@@ -82,6 +82,7 @@ func SignCertificate(publicKey ssh.PublicKey, token string, forceCommand string,
// Parse the result form the lambda to extract the certificate
sshrimpResult := SSHrimpResult{}
Log.Logger.Tracef("parsing result: %v", string(resbody))
err = json.Unmarshal(resbody, &sshrimpResult)
if err != nil {
return nil, fmt.Errorf("failed to parse json response from sshrimp-ca: %w: %v", err, string(resbody))
@@ -98,14 +99,14 @@ func SignCertificate(publicKey ssh.PublicKey, token string, forceCommand string,
// Parse the certificate received by sshrimp-ca
cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshrimpResult.Certificate))
Log.Logger.Tracef("parsing cert: %v", err)
if err != nil {
return nil, err
}
return cert.(*ssh.Certificate), nil
}
func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, functionID string) (ssh.Certificate, error) {
func ValidateRequest(log *logrus.Entry, event SSHrimpEvent, c *config.SSHrimp, requestID string, ca ssh.PublicKey) (ssh.Certificate, error) {
// Validate the user supplied public key
publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(event.PublicKey))
if err != nil {
@@ -113,7 +114,7 @@ func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, fu
}
// Validate the user supplied identity token with the loaded configuration
i, err := identity.NewIdentity(c)
i, err := identity.NewIdentity(log, c)
if err != nil {
return ssh.Certificate{}, err
}
@@ -180,7 +181,7 @@ func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, fu
event.SourceAddress,
event.ForceCommand,
ssh.FingerprintSHA256(publicKey),
functionID,
ssh.FingerprintSHA256(ca),
validBefore.Format("2006/01/02 15:04:05"),
)

View File

@@ -2,6 +2,7 @@ package sshrimpagent
import (
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -9,7 +10,8 @@ import (
"golang.org/x/crypto/ssh"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
sshrimp_http "gitea.narnian.us/lordwelch/sshrimp/internal/http"
"github.com/google/uuid"
"github.com/zitadel/oidc/pkg/client/rp"
httphelper "github.com/zitadel/oidc/pkg/http"
@@ -17,9 +19,7 @@ import (
"golang.org/x/exp/slices"
)
var (
key = []byte(uuid.New().String())[:16]
)
var hashKey = []byte(uuid.New().String())[:16]
type OidcClient struct {
ListenAddress string
@@ -38,7 +38,7 @@ func newOIDCClient(c *config.SSHrimp) (*OidcClient, error) {
c.Agent.Scopes = append([]string{"scopes"}, c.Agent.Scopes...)
}
token_chan := make(chan *oidc.Tokens)
token := make(chan *oidc.Tokens)
oidcMux := http.NewServeMux()
return &OidcClient{
@@ -51,7 +51,7 @@ func newOIDCClient(c *config.SSHrimp) (*OidcClient, error) {
WriteTimeout: time.Minute / 2,
IdleTimeout: time.Minute / 2,
},
OIDCToken: token_chan,
OIDCToken: token,
Certificate: &ssh.Certificate{},
SSHrimp: c,
}, nil
@@ -73,7 +73,7 @@ func (o *OidcClient) ListenAndServe() error {
if err = o.setupHandlers(); err != nil {
return err
}
return o.Server.Serve(ln)
return o.Serve(ln)
}
func (o *OidcClient) setupHandlers() error {
@@ -81,12 +81,21 @@ func (o *OidcClient) setupHandlers() error {
redirectURI.Path = "/auth/callback"
successURI := o.baseURI()
successURI.Path = "/success"
var CAKey []byte
resp, err := sshrimp_http.Client.Get(o.Agent.CAUrls[0])
if err == nil && resp.Header.Get("Content-Type") == "text/x-ssh-public-key" {
CAKey, err = io.ReadAll(resp.Body)
if err != nil {
CAKey = []byte{}
}
}
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
cookieHandler := httphelper.NewCookieHandler(hashKey, nil)
options := []rp.Option{
rp.WithCookieHandler(cookieHandler),
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
rp.WithHTTPClient(sshrimp_http.Client),
}
options = append(options, rp.WithPKCE(cookieHandler))
if o.Agent.KeyPath != "" {
@@ -110,19 +119,34 @@ func (o *OidcClient) setupHandlers() error {
o.oidcMux.Handle("/login", rp.AuthURLHandler(state, provider))
o.oidcMux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if o.Certificate != nil && o.Certificate.SignatureKey != nil {
fmt.Fprintf(w, "The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
Log.Printf("The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
key := ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey)
if len(CAKey) < 3 {
CAKey = key
}
if !slices.Equal(key, CAKey) {
Log.Errorf("Certificate Authority key has changed from %#v to %#v", string(CAKey), string(key))
fmt.Fprintf(w, "\n\nCertificate Authority key has changed from \n%#v\nto \n%#v", string(CAKey), string(key))
}
}
fmt.Fprintf(w, "The SSH CA currently in use is:\n%s", CAKey)
Log.Printf("The SSH CA currently in use is:\n%s", CAKey)
}))
o.oidcMux.Handle(successURI.Path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Return to the CLI.")
if o.Certificate != nil && o.Certificate.SignatureKey != nil {
fmt.Fprintf(w, "The SSH CA currently in use is: %s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
Log.Printf("The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
key := ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey)
if len(CAKey) < 3 {
CAKey = key
}
if !slices.Equal(key, CAKey) {
Log.Errorf("Certificate Authority key has changed from %#v to %#v", string(CAKey), string(key))
fmt.Fprintf(w, "\n\nCertificate Authority key has changed from \n%#v\nto \n%#v", string(CAKey), string(key))
}
}
fmt.Fprintf(w, "The SSH CA currently in use is: %s", CAKey)
Log.Printf("The SSH CA currently in use is:\n%s", CAKey)
}))
// for demonstration purposes the returned userinfo response is written as JSON object onto response
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
o.OIDCToken <- tokens
w.Header().Add("location", successURI.String())

View File

@@ -4,11 +4,10 @@ import (
"crypto/rand"
"errors"
"net/http"
"os"
"time"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"git.narnian.us/lordwelch/sshrimp/internal/signer"
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
"gitea.narnian.us/lordwelch/sshrimp/internal/signer"
"github.com/pkg/browser"
"github.com/sirupsen/logrus"
"github.com/zitadel/oidc/pkg/oidc"
@@ -28,17 +27,16 @@ type sshrimpAgent struct {
// NewSSHrimpAgent returns an agent.Agent capable of signing certificates with a SSHrimp Certificate Authority
func NewSSHrimpAgent(c *config.SSHrimp, signer ssh.Signer) (agent.Agent, error) {
oidcClient, err := newOIDCClient(c)
if err != nil {
return nil, err
}
go func() {
if err = oidcClient.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
Log.Logger.Errorf("Server failed: %v", err)
os.Exit(99)
for {
if err = oidcClient.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
Log.Logger.Errorf("Server failed: %v", err)
}
}
}()
@@ -105,7 +103,6 @@ func (r *sshrimpAgent) List() ([]*agent.Key, error) {
Log.Traceln("Certificate has expired")
Log.Traceln("authenticating token")
err := r.authenticate()
if err != nil {
Log.Errorf("authenticating the token failed: %v", err)
return nil, err
@@ -172,6 +169,7 @@ func (r *sshrimpAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent
Log.Traceln("signing data")
return r.Sign(key, data)
}
func (r *sshrimpAgent) Extension(extensionType string, contents []byte) ([]byte, error) {
return nil, agent.ErrExtensionUnsupported
}