Compare commits
5 Commits
00de399557
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c570b14537 | ||
|
|
58016c6889 | ||
|
|
8989dc25ac | ||
|
|
899aad07b2 | ||
|
|
2bc73596d3 |
@@ -2,13 +2,12 @@ package main
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -17,13 +16,15 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/signer"
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/sshrimpagent"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/signer"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/sshrimpagent"
|
||||
"github.com/prometheus/procfs"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus/hooks/writer"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
"inet.af/peercred"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -38,6 +39,7 @@ type cfg struct {
|
||||
Config string
|
||||
LogDirectory string
|
||||
Verbose bool
|
||||
Foreground bool
|
||||
}
|
||||
|
||||
func getLogDir() string {
|
||||
@@ -72,8 +74,8 @@ func setupLoging(config cfg) error {
|
||||
logrus.ErrorLevel,
|
||||
logrus.WarnLevel,
|
||||
}
|
||||
err := os.MkdirAll(config.LogDirectory, 0750)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
err := os.MkdirAll(config.LogDirectory, 0o750)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -91,7 +93,7 @@ func setupLoging(config cfg) error {
|
||||
Writer: os.Stderr,
|
||||
LogLevels: levels,
|
||||
})
|
||||
logger.Out = ioutil.Discard
|
||||
logger.Out = io.Discard
|
||||
file, err := os.Create(logName)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -122,36 +124,125 @@ func ExpandPath(path string) string {
|
||||
return path
|
||||
}
|
||||
|
||||
func main2(cli cfg, c *config.SSHrimp) {
|
||||
listener := openSocket(ExpandPath(c.Agent.Socket))
|
||||
if listener == nil {
|
||||
log.Errorln("Failed to open socket")
|
||||
return
|
||||
}
|
||||
err := setupLoging(cli)
|
||||
if err != nil {
|
||||
log.Warnf("Error setting up logging: %v", err)
|
||||
}
|
||||
err = launchAgent(c, listener)
|
||||
if err != nil {
|
||||
log.Panic("Failed to launch agent", err)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
defaultConfigPath := "~/.ssh/sshrimp.toml"
|
||||
if configPathFromEnv, ok := os.LookupEnv("SSHRIMP_CONFIG"); ok && configPathFromEnv != "" {
|
||||
defaultConfigPath = configPathFromEnv
|
||||
}
|
||||
var cli cfg
|
||||
var (
|
||||
cli cfg
|
||||
err error
|
||||
)
|
||||
flag.StringVar(&cli.Config, "config", defaultConfigPath, "sshrimp config file")
|
||||
flag.StringVar(&cli.LogDirectory, "log", getLogDir(), "sshrimp log directory")
|
||||
flag.BoolVar(&cli.Verbose, "v", false, "enable verbose logging")
|
||||
fmt.Println(getLogDir())
|
||||
flag.BoolVar(&cli.Foreground, "f", false, "Run in the foreground")
|
||||
|
||||
flag.Parse()
|
||||
sshCommand := flag.Args()
|
||||
if cli.Verbose {
|
||||
logger.SetLevel(logrus.DebugLevel)
|
||||
}
|
||||
|
||||
cfgFile := ExpandPath(cli.Config)
|
||||
cfgFile, err = filepath.Abs(cfgFile)
|
||||
if err != nil {
|
||||
log.Errorln("config must be an absolute path")
|
||||
os.Exit(1)
|
||||
}
|
||||
c := config.NewSSHrimpWithDefaults()
|
||||
err := c.Read(ExpandPath(cli.Config))
|
||||
err = c.Read(cfgFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
listener := openSocket(ExpandPath(c.Agent.Socket))
|
||||
if listener == nil {
|
||||
logger.Errorln("Failed to open socket")
|
||||
return
|
||||
if os.Getenv("SSHRIMP_DAEMON") == "true" {
|
||||
cli.Foreground = true
|
||||
}
|
||||
if err := setupLoging(cli); err != nil {
|
||||
logger.Warnf("Error setting up logging: %v", err)
|
||||
}
|
||||
err = launchAgent(c, listener)
|
||||
if cli.Foreground {
|
||||
logger.Println("Launching agent")
|
||||
main2(cli, c)
|
||||
} else {
|
||||
logger.Debug("Attempting to start daemon")
|
||||
var nullFile *os.File
|
||||
nullFile, err = os.Open(os.DevNull)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
env := os.Environ()
|
||||
env = append(env, "SSHRIMP_DAEMON=true")
|
||||
executable, err := os.Executable()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = os.StartProcess(executable, os.Args, &os.ProcAttr{
|
||||
Dir: filepath.Dir(cfgFile),
|
||||
Env: env,
|
||||
Files: []*os.File{nullFile, nullFile, nullFile},
|
||||
Sys: &syscall.SysProcAttr{
|
||||
// Chroot: d.Chroot,
|
||||
Setsid: true,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
nullFile.Close()
|
||||
logger.Debugf("Agent started in the background check %s for logs", getLogDir())
|
||||
}
|
||||
if len(sshCommand) > 1 && filepath.Base(sshCommand[0]) == "ssh" {
|
||||
syscall.Exec(sshCommand[0], sshCommand, os.Environ())
|
||||
}
|
||||
}
|
||||
|
||||
func socketWorks(path string) bool {
|
||||
var (
|
||||
pid int
|
||||
cred *peercred.Creds
|
||||
)
|
||||
conn, sockErr := net.Dial("unix", path)
|
||||
if sockErr != nil {
|
||||
return false
|
||||
}
|
||||
if conn == nil {
|
||||
return false
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
cred, sockErr = peercred.Get(conn)
|
||||
if sockErr != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
ok bool
|
||||
process *os.Process
|
||||
)
|
||||
pid, ok = cred.PID()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
process, sockErr = os.FindProcess(pid)
|
||||
if sockErr != nil {
|
||||
return false
|
||||
}
|
||||
defer process.Release()
|
||||
return process.Signal(syscall.SIGHUP) == nil
|
||||
}
|
||||
|
||||
func openSocket(socketPath string) net.Listener {
|
||||
@@ -161,39 +252,80 @@ func openSocket(socketPath string) net.Listener {
|
||||
logMessage string
|
||||
)
|
||||
|
||||
if _, err = os.Stat(socketPath); err == nil {
|
||||
fmt.Println("Creating socket")
|
||||
fmt.Printf("File already exists at %s\n", socketPath)
|
||||
conn, sockErr := net.Dial("unix", socketPath)
|
||||
if conn == nil {
|
||||
logMessage = "conn is nil"
|
||||
}
|
||||
if sockErr == nil { // socket is accepting connections
|
||||
conn.Close()
|
||||
fmt.Printf("socket %s already exists\n", socketPath)
|
||||
if socketWorks(socketPath) { // socket is accepting connections
|
||||
log.Printf("socket %s already exists\n", socketPath)
|
||||
return nil
|
||||
}
|
||||
fmt.Printf("Socket is not connected %s\n", logMessage)
|
||||
log.Printf("Socket is not connected %s\n", logMessage)
|
||||
err = os.Remove(socketPath)
|
||||
if err == nil { // socket is not accepting connections, assuming safe to remove
|
||||
fmt.Println("Deleting socket: success")
|
||||
} else {
|
||||
fmt.Println("Deleting socket: fail", err)
|
||||
log.Println("Deleting socket: success")
|
||||
} else if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Println("Deleting socket: fail", err)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// This affects all files created for the process. Since this is a sensitive
|
||||
// socket, only allow the current user to write to the socket.
|
||||
syscall.Umask(0o077)
|
||||
listener, err = net.Listen("unix", socketPath)
|
||||
if err != nil {
|
||||
fmt.Println("Error opening socket:", err)
|
||||
log.Println("Error opening socket:", err)
|
||||
return nil
|
||||
}
|
||||
log.Println("Opened socket", socketPath)
|
||||
return listener
|
||||
}
|
||||
|
||||
func getConnectedProcess(conn net.Conn) string {
|
||||
var (
|
||||
cred *peercred.Creds
|
||||
err error
|
||||
)
|
||||
cred, err = peercred.Get(conn)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
pid, ok := cred.PID()
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
var (
|
||||
proc procfs.Proc
|
||||
name string
|
||||
)
|
||||
proc, err = procfs.NewProc(pid)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("pid %d", pid)
|
||||
}
|
||||
name, err = proc.Executable()
|
||||
if err == nil {
|
||||
return fmt.Sprintf("pid %d", pid)
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func handle(sshrimpAgent agent.Agent, conn net.Conn) (err error) {
|
||||
defer func() {
|
||||
panicErr := recover()
|
||||
|
||||
if panicErr != nil {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("something panicked: %w: %v", err, panicErr)
|
||||
return
|
||||
}
|
||||
err, _ = panicErr.(error)
|
||||
return
|
||||
}
|
||||
}()
|
||||
log.Infof("Serving agent to %s", getConnectedProcess(conn))
|
||||
if err = agent.ServeAgent(sshrimpAgent, conn); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Errorf("Error serving agent: %v", err)
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func launchAgent(c *config.SSHrimp, listener net.Listener) error {
|
||||
var (
|
||||
err error
|
||||
@@ -202,11 +334,11 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
|
||||
)
|
||||
defer listener.Close()
|
||||
|
||||
fmt.Printf("listening on %s\n", c.Agent.Socket)
|
||||
log.Printf("listening on %s\n", c.Agent.Socket)
|
||||
|
||||
// Generate a new SSH private/public key pair
|
||||
log.Tracef("Generating RSA %d ssh keys", 2048)
|
||||
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
log.Tracef("Generating ed25519 ssh keys")
|
||||
_, privateKey, err = ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -220,7 +352,7 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
|
||||
log.Traceln("Creating new sshrimp agent from sshSigner and config")
|
||||
sshrimpAgent, err := sshrimpagent.NewSSHrimpAgent(c, sshSigner)
|
||||
if err != nil {
|
||||
log.Logger.Errorf("Failed to create sshrimpAgent: %v", err)
|
||||
log.Errorf("Failed to create sshrimpAgent: %v", err)
|
||||
}
|
||||
|
||||
// Listen for signals so that we can close the listener and exit nicely
|
||||
@@ -230,8 +362,9 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
|
||||
osSignals := make(chan os.Signal, 10)
|
||||
signal.Notify(osSignals, sigExit...)
|
||||
go func() {
|
||||
<-osSignals
|
||||
listener.Close()
|
||||
sig := <-osSignals
|
||||
log.Infof("Recieved signal %v: closing", sig)
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
log.Traceln("Starting main loop")
|
||||
@@ -243,14 +376,10 @@ func launchAgent(c *config.SSHrimp, listener net.Listener) error {
|
||||
log.Errorf("Error accepting connection: %v", err)
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
// Occurs if the user interrupts the agent with a ctrl-c signal
|
||||
return nil
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Traceln("Serving agent")
|
||||
if err = agent.ServeAgent(sshrimpAgent, conn); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Errorf("Error serving agent: %v", err)
|
||||
return err
|
||||
log.Errorf("Error accepting connection: %v", err)
|
||||
}
|
||||
go handle(sshrimpAgent, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
//go:build darwin || linux
|
||||
// +build darwin linux
|
||||
|
||||
package main
|
||||
|
||||
|
||||
@@ -14,14 +14,17 @@ import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/signer"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/signer"
|
||||
"github.com/BurntSushi/toml"
|
||||
gonanoid "github.com/matoous/go-nanoid/v2"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func httpError(w http.ResponseWriter, v interface{}, statusCode int) {
|
||||
func httpError(w http.ResponseWriter, v any, statusCode int) {
|
||||
var b bytes.Buffer
|
||||
e := json.NewEncoder(&b)
|
||||
_ = e.Encode(v)
|
||||
@@ -31,16 +34,22 @@ func httpError(w http.ResponseWriter, v interface{}, statusCode int) {
|
||||
type Server struct {
|
||||
config *config.SSHrimp
|
||||
Key ssh.Signer
|
||||
Log *logrus.Logger
|
||||
}
|
||||
|
||||
func NewServer(cfg *config.SSHrimp) (*Server, error) {
|
||||
server := &Server{
|
||||
config: cfg,
|
||||
Log: logrus.New(),
|
||||
}
|
||||
return server, nil
|
||||
server.Log.SetLevel(logrus.DebugLevel)
|
||||
return server, server.LoadKey()
|
||||
}
|
||||
|
||||
func (s *Server) LoadKey() error {
|
||||
if s.config.CertificateAuthority.KeyPath == "" {
|
||||
return fmt.Errorf("key path missing")
|
||||
}
|
||||
b, err := os.ReadFile(s.config.CertificateAuthority.KeyPath)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
@@ -79,14 +88,18 @@ func (s *Server) GenerateKey() error {
|
||||
|
||||
// ServeHTTP handles a request to sign an SSH public key verified by an OpenIDConnect id_token
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
txid := gonanoid.Must()
|
||||
log := s.Log.WithField("X-Request-ID", txid)
|
||||
w.Header().Add("X-Request-ID", txid)
|
||||
defer r.Body.Close()
|
||||
if r.URL.Path == "/config" {
|
||||
if strings.HasPrefix(r.URL.Path, "/config") {
|
||||
io.Copy(io.Discard, r.Body)
|
||||
new_config := *s.config
|
||||
// new_config.CertificateAuthority
|
||||
newConfig := *s.config
|
||||
newConfig.CertificateAuthority = config.CertificateAuthority{}
|
||||
w.Header().Add("Content-Type", "application/toml")
|
||||
w.Header().Add("Content-Disposition", `attachment; filename="sshrimp.toml"`)
|
||||
t := toml.NewEncoder(w)
|
||||
_ = t.Encode(new_config)
|
||||
_ = t.Encode(newConfig)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
@@ -103,7 +116,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
certificate, err := signer.ValidateRequest(event, s.config, r.Header.Get("Function-Execution-Id"), fmt.Sprintf("%s/%s/%s", os.Getenv("GCP_PROJECT"), os.Getenv("FUNCTION_REGION"), os.Getenv("FUNCTION_NAME")))
|
||||
certificate, err := signer.ValidateRequest(log, event, s.config, txid, s.Key.PublicKey())
|
||||
if err != nil {
|
||||
httpError(w, signer.SSHrimpResult{Certificate: "", ErrorMessage: err.Error(), ErrorType: http.StatusText(http.StatusBadRequest)}, http.StatusBadRequest)
|
||||
return
|
||||
@@ -151,6 +164,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func main() {
|
||||
cfgFile := flag.String("config", "/etc/sshrimp.toml", "Path to sshrimp.toml")
|
||||
addr := flag.String("addr", "127.0.0.1:8080", "Address to listen on")
|
||||
flag.Parse()
|
||||
cfg := config.NewSSHrimp()
|
||||
if err := cfg.Read(*cfgFile); err != nil {
|
||||
log.Printf("Unable to read config file %s: %v", *cfgFile, err)
|
||||
|
||||
9
go.mod
9
go.mod
@@ -1,4 +1,4 @@
|
||||
module git.narnian.us/lordwelch/sshrimp
|
||||
module gitea.narnian.us/lordwelch/sshrimp
|
||||
|
||||
go 1.24.0
|
||||
|
||||
@@ -8,19 +8,24 @@ require (
|
||||
github.com/BurntSushi/toml v1.6.0
|
||||
github.com/coreos/go-oidc/v3 v3.17.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
|
||||
github.com/prometheus/procfs v0.19.2
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/zitadel/oidc v1.13.5
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93
|
||||
inet.af/peercred v0.0.0-20210906144145-0893ea02156a
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/go-jose/go-jose/v4 v4.1.3 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/gorilla/schema v1.4.1 // indirect
|
||||
github.com/gorilla/securecookie v1.1.2 // indirect
|
||||
github.com/jeremija/gosubmit v0.2.8 // indirect
|
||||
github.com/rs/cors v1.11.1 // indirect
|
||||
github.com/stretchr/testify v1.11.1 // indirect
|
||||
github.com/zitadel/logging v0.6.2 // indirect
|
||||
golang.org/x/net v0.48.0 // indirect
|
||||
golang.org/x/oauth2 v0.34.0 // indirect
|
||||
golang.org/x/sys v0.40.0 // indirect
|
||||
|
||||
23
go.sum
23
go.sum
@@ -19,38 +19,47 @@ github.com/gorilla/schema v1.4.1 h1:jUg5hUjCSDZpNGLuXQOgIWGdlgrIdYvgQ0wZtdK1M3E=
|
||||
github.com/gorilla/schema v1.4.1/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM=
|
||||
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
|
||||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/jeremija/gosubmit v0.2.7 h1:At0OhGCFGPXyjPYAsCchoBUhE099pcBXmsb4iZqROIc=
|
||||
github.com/jeremija/gosubmit v0.2.7/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI=
|
||||
github.com/jeremija/gosubmit v0.2.8 h1:mmSITBz9JxVtu8eqbN+zmmwX7Ij2RidQxhcwRVI4wqA=
|
||||
github.com/jeremija/gosubmit v0.2.8/go.mod h1:Ui+HS073lCFREXBbdfrJzMB57OI/bdxTiLtrDHHhFPI=
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0 h1:P64+dmq21hhWdtvZfEAofnvJULaRR1Yib0+PnU669bE=
|
||||
github.com/matoous/go-nanoid/v2 v2.1.0/go.mod h1:KlbGNQ+FhrUNIHUxZdL63t7tl4LaPkZNpUULS8H4uVM=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rs/cors v1.10.1 h1:L0uuZVXIKlI1SShY2nhFfo44TYvDPQ1w4oFkUJNfhyo=
|
||||
github.com/rs/cors v1.10.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws=
|
||||
github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw=
|
||||
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
|
||||
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/zitadel/logging v0.3.4 h1:9hZsTjMMTE3X2LUi0xcF9Q9EdLo+FAezeu52ireBbHM=
|
||||
github.com/zitadel/logging v0.3.4/go.mod h1:aPpLQhE+v6ocNK0TWrBrd363hZ95KcI17Q1ixAQwZF0=
|
||||
github.com/zitadel/logging v0.6.2 h1:MW2kDDR0ieQynPZ0KIZPrh9ote2WkxfBif5QoARDQcU=
|
||||
github.com/zitadel/logging v0.6.2/go.mod h1:z6VWLWUkJpnNVDSLzrPSQSQyttysKZ6bCRongw0ROK4=
|
||||
github.com/zitadel/oidc v1.13.5 h1:7jhh68NGZitLqwLiVU9Dtwa4IraJPFF1vS+4UupO93U=
|
||||
github.com/zitadel/oidc v1.13.5/go.mod h1:rHs1DhU3Sv3tnI6bQRVlFa3u0lCwtR7S21WHY+yXgPA=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
|
||||
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 h1:fQsdNF2N+/YewlRZiricy4P1iimyPKZ/xwniHj8Q2a0=
|
||||
golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93/go.mod h1:EPRbTFwzwjXj9NpYyyrvenVh9Y+GFeEvMNh7Xuz7xgU=
|
||||
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
|
||||
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20210301091718-77cc2087c03b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
|
||||
golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
|
||||
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
|
||||
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
@@ -59,3 +68,5 @@ gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
inet.af/peercred v0.0.0-20210906144145-0893ea02156a h1:qdkS8Q5/i10xU2ArJMKYhVa1DORzBfYS/qA2UK2jheg=
|
||||
inet.af/peercred v0.0.0-20210906144145-0893ea02156a/go.mod h1:FjawnflS/udxX+SvpsMgZfdqx2aykOlkISeAsADi5IU=
|
||||
|
||||
@@ -88,7 +88,7 @@ func NewSSHrimpWithDefaults() *SSHrimp {
|
||||
return &sshrimp
|
||||
}
|
||||
|
||||
func validateInt(val interface{}) error {
|
||||
func validateInt(val any) error {
|
||||
if str, ok := val.(string); ok {
|
||||
if _, err := strconv.Atoi(str); err != nil {
|
||||
return err
|
||||
@@ -100,7 +100,7 @@ func validateInt(val interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateURL(val interface{}) error {
|
||||
func validateURL(val any) error {
|
||||
if str, ok := val.(string); ok {
|
||||
if _, err := url.ParseRequestURI(str); err != nil {
|
||||
return err
|
||||
@@ -112,7 +112,7 @@ func validateURL(val interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDuration(val interface{}) error {
|
||||
func validateDuration(val any) error {
|
||||
if str, ok := val.(string); ok {
|
||||
if _, err := time.ParseDuration(str); err != nil {
|
||||
return err
|
||||
@@ -124,7 +124,7 @@ func validateDuration(val interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateAlias(val interface{}) error {
|
||||
func validateAlias(val any) error {
|
||||
if str, ok := val.(string); ok {
|
||||
if !strings.HasPrefix(str, "alias/") {
|
||||
return errors.New("KMS alias must begin with alias/")
|
||||
|
||||
39
internal/http/client.go
Normal file
39
internal/http/client.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Transport struct {
|
||||
http.RoundTripper
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if req.Header.Get("User-Agent") == "" {
|
||||
req.Header.Set("User-Agent", t.UserAgent)
|
||||
}
|
||||
if t.RoundTripper == nil {
|
||||
d := &net.Dialer{
|
||||
Timeout: 2 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
t.RoundTripper = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: d.DialContext,
|
||||
ForceAttemptHTTP2: false,
|
||||
MaxIdleConns: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 2 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
}
|
||||
return t.RoundTripper.RoundTrip(req)
|
||||
}
|
||||
|
||||
var Client = &http.Client{
|
||||
Transport: &Transport{UserAgent: "sshrimp-agent"},
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
@@ -3,57 +3,27 @@ package identity
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Disable log prefixes such as the default timestamp.
|
||||
// Prefix text prevents the message from being parsed as JSON.
|
||||
// A timestamp is added when shipping logs to Cloud Logging.
|
||||
log.SetFlags(0)
|
||||
}
|
||||
|
||||
// Entry defines a log entry.
|
||||
type Entry struct {
|
||||
Message string `json:"message"`
|
||||
Severity string `json:"severity,omitempty"`
|
||||
Trace string `json:"logging.googleapis.com/trace,omitempty"`
|
||||
|
||||
// Logs Explorer allows filtering and display of this as `jsonPayload.component`.
|
||||
Component string `json:"component,omitempty"`
|
||||
}
|
||||
|
||||
// String renders an entry structure to the JSON format expected by Cloud Logging.
|
||||
func (e Entry) String() string {
|
||||
if e.Severity == "" {
|
||||
e.Severity = "INFO"
|
||||
}
|
||||
out, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
log.Printf("json.Marshal: %v", err)
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
|
||||
// Identity holds information required to verify an OIDC identity token
|
||||
type Identity struct {
|
||||
ctx context.Context
|
||||
verifier *oidc.IDTokenVerifier
|
||||
usernameREs []*regexp.Regexp
|
||||
usernameClaims []string
|
||||
log *logrus.Entry
|
||||
}
|
||||
|
||||
// NewIdentity return a new Identity, with default values and oidc proivder information populated
|
||||
func NewIdentity(c *config.SSHrimp) (*Identity, error) {
|
||||
func NewIdentity(log *logrus.Entry, c *config.SSHrimp) (*Identity, error) {
|
||||
ctx := context.Background()
|
||||
provider, err := oidc.NewProvider(ctx, c.Agent.ProviderURL)
|
||||
if err != nil {
|
||||
@@ -75,6 +45,7 @@ func NewIdentity(c *config.SSHrimp) (*Identity, error) {
|
||||
verifier: provider.Verifier(oidcConfig),
|
||||
usernameREs: regexes,
|
||||
usernameClaims: c.CertificateAuthority.UsernameClaims,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -89,20 +60,17 @@ func (i *Identity) Validate(token string) ([]string, error) {
|
||||
}
|
||||
|
||||
func (i *Identity) getUsernames(idToken *oidc.IDToken) ([]string, error) {
|
||||
var claims map[string]interface{}
|
||||
var claims map[string]any
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return nil, errors.New("failed to parse claims: " + err.Error())
|
||||
}
|
||||
usernames := make([]string, 0, len(i.usernameClaims))
|
||||
for idx, claim := range i.usernameClaims {
|
||||
|
||||
claimedUsernames := getClaim(claim, claims)
|
||||
claimedUsernames := i.getClaim(claim, claims)
|
||||
|
||||
if len(claimedUsernames) == 0 {
|
||||
log.Println(Entry{
|
||||
Severity: "NOTICE",
|
||||
Message: fmt.Sprintf("Did not find a username using: getClaim(%#v, %#v)", claim, claims),
|
||||
})
|
||||
i.log.Errorf("Did not find a username using: getClaim(%#v, %#v)", claim, claims)
|
||||
}
|
||||
|
||||
if idx < len(i.usernameREs) {
|
||||
@@ -114,10 +82,7 @@ func (i *Identity) getUsernames(idToken *oidc.IDToken) ([]string, error) {
|
||||
|
||||
}
|
||||
}
|
||||
log.Println(Entry{
|
||||
Severity: "NOTICE",
|
||||
Message: fmt.Sprintf("Adding usernames: %v", usernames),
|
||||
})
|
||||
i.log.Infof("Adding usernames: %v", usernames)
|
||||
if len(usernames) < 1 {
|
||||
return nil, errors.New("configured username claim not in identity token")
|
||||
}
|
||||
@@ -131,13 +96,13 @@ func parseUsername(username string, re *regexp.Regexp) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func getClaim(claim string, claims map[string]interface{}) []string {
|
||||
func (i *Identity) getClaim(claim string, claims map[string]any) []string {
|
||||
usernames := make([]string, 0, 2)
|
||||
parts := strings.Split(claim, ".")
|
||||
f:
|
||||
for idx, part := range parts {
|
||||
switch v := claims[part].(type) {
|
||||
case map[string]interface{}:
|
||||
case map[string]any:
|
||||
claims = v
|
||||
case []map[string]string:
|
||||
for _, claimItem := range v {
|
||||
@@ -147,7 +112,7 @@ f:
|
||||
}
|
||||
}
|
||||
break f
|
||||
case []interface{}:
|
||||
case []any:
|
||||
for _, value := range v {
|
||||
if name, ok := value.(string); ok {
|
||||
usernames = append(usernames, name)
|
||||
@@ -161,21 +126,15 @@ f:
|
||||
}
|
||||
|
||||
}
|
||||
return base64Decode(usernames)
|
||||
return i.base64Decode(usernames)
|
||||
}
|
||||
func base64Decode(names []string) []string {
|
||||
func (i *Identity) base64Decode(names []string) []string {
|
||||
for idx, name := range names {
|
||||
log.Println(Entry{
|
||||
Severity: "NOTICE",
|
||||
Message: fmt.Sprintf("Attempting to decode %q as base64\n", name),
|
||||
})
|
||||
i.log.Debugf("Attempting to decode %q as base64\n", name)
|
||||
decoded, err := base64.RawURLEncoding.DecodeString(name)
|
||||
if err == nil && utf8.Valid(decoded) {
|
||||
names[idx] = string(decoded)
|
||||
log.Println(Entry{
|
||||
Severity: "NOTICE",
|
||||
Message: fmt.Sprintf("Successfully decoded %q as base64\n", names[idx]),
|
||||
})
|
||||
i.log.Debugf("Successfully decoded %q as base64\n", names[idx])
|
||||
}
|
||||
}
|
||||
return names
|
||||
|
||||
@@ -5,19 +5,18 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"errors"
|
||||
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/identity"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/http"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/identity"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@@ -43,11 +42,12 @@ type SSHrimpEvent struct {
|
||||
// SignCertificateAllURLs iterate through each configured url if there is an error signing the certificate
|
||||
func SignCertificateAllURLs(publicKey ssh.PublicKey, token string, forceCommand string, urls []string) (*ssh.Certificate, error) {
|
||||
var (
|
||||
err error
|
||||
err = fmt.Errorf("no urls found to sign certificate")
|
||||
cert *ssh.Certificate
|
||||
)
|
||||
|
||||
// Try each configured url before exiting if there is an error
|
||||
Log.Logger.Tracef("Attempting to sign cert with urls %v", urls)
|
||||
for _, url := range urls {
|
||||
cert, err = SignCertificate(publicKey, token, forceCommand, url)
|
||||
if err == nil {
|
||||
@@ -69,12 +69,12 @@ func SignCertificate(publicKey ssh.PublicKey, token string, forceCommand string,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var uri string
|
||||
|
||||
result, err := http.Post(uri, "application/json", bytes.NewReader(payload))
|
||||
Log.Logger.Tracef("Posting to url %s", url)
|
||||
result, err := http.Client.Post(url, "application/json", bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("http post failed: %w", err)
|
||||
}
|
||||
Log.Logger.Tracef("Reading body length: %d", result.ContentLength)
|
||||
resbody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to retrieve the response from sshrimp-ca: %w", err)
|
||||
@@ -82,6 +82,7 @@ func SignCertificate(publicKey ssh.PublicKey, token string, forceCommand string,
|
||||
|
||||
// Parse the result form the lambda to extract the certificate
|
||||
sshrimpResult := SSHrimpResult{}
|
||||
Log.Logger.Tracef("parsing result: %v", string(resbody))
|
||||
err = json.Unmarshal(resbody, &sshrimpResult)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse json response from sshrimp-ca: %w: %v", err, string(resbody))
|
||||
@@ -98,14 +99,14 @@ func SignCertificate(publicKey ssh.PublicKey, token string, forceCommand string,
|
||||
|
||||
// Parse the certificate received by sshrimp-ca
|
||||
cert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(sshrimpResult.Certificate))
|
||||
Log.Logger.Tracef("parsing cert: %v", err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cert.(*ssh.Certificate), nil
|
||||
}
|
||||
|
||||
|
||||
func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, functionID string) (ssh.Certificate, error) {
|
||||
func ValidateRequest(log *logrus.Entry, event SSHrimpEvent, c *config.SSHrimp, requestID string, ca ssh.PublicKey) (ssh.Certificate, error) {
|
||||
// Validate the user supplied public key
|
||||
publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(event.PublicKey))
|
||||
if err != nil {
|
||||
@@ -113,7 +114,7 @@ func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, fu
|
||||
}
|
||||
|
||||
// Validate the user supplied identity token with the loaded configuration
|
||||
i, err := identity.NewIdentity(c)
|
||||
i, err := identity.NewIdentity(log, c)
|
||||
if err != nil {
|
||||
return ssh.Certificate{}, err
|
||||
}
|
||||
@@ -180,7 +181,7 @@ func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, fu
|
||||
event.SourceAddress,
|
||||
event.ForceCommand,
|
||||
ssh.FingerprintSHA256(publicKey),
|
||||
functionID,
|
||||
ssh.FingerprintSHA256(ca),
|
||||
validBefore.Format("2006/01/02 15:04:05"),
|
||||
)
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package sshrimpagent
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -9,7 +10,8 @@ import (
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
sshrimp_http "gitea.narnian.us/lordwelch/sshrimp/internal/http"
|
||||
"github.com/google/uuid"
|
||||
"github.com/zitadel/oidc/pkg/client/rp"
|
||||
httphelper "github.com/zitadel/oidc/pkg/http"
|
||||
@@ -17,9 +19,7 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var (
|
||||
key = []byte(uuid.New().String())[:16]
|
||||
)
|
||||
var hashKey = []byte(uuid.New().String())[:16]
|
||||
|
||||
type OidcClient struct {
|
||||
ListenAddress string
|
||||
@@ -38,7 +38,7 @@ func newOIDCClient(c *config.SSHrimp) (*OidcClient, error) {
|
||||
c.Agent.Scopes = append([]string{"scopes"}, c.Agent.Scopes...)
|
||||
}
|
||||
|
||||
token_chan := make(chan *oidc.Tokens)
|
||||
token := make(chan *oidc.Tokens)
|
||||
|
||||
oidcMux := http.NewServeMux()
|
||||
return &OidcClient{
|
||||
@@ -51,7 +51,7 @@ func newOIDCClient(c *config.SSHrimp) (*OidcClient, error) {
|
||||
WriteTimeout: time.Minute / 2,
|
||||
IdleTimeout: time.Minute / 2,
|
||||
},
|
||||
OIDCToken: token_chan,
|
||||
OIDCToken: token,
|
||||
Certificate: &ssh.Certificate{},
|
||||
SSHrimp: c,
|
||||
}, nil
|
||||
@@ -73,7 +73,7 @@ func (o *OidcClient) ListenAndServe() error {
|
||||
if err = o.setupHandlers(); err != nil {
|
||||
return err
|
||||
}
|
||||
return o.Server.Serve(ln)
|
||||
return o.Serve(ln)
|
||||
}
|
||||
|
||||
func (o *OidcClient) setupHandlers() error {
|
||||
@@ -81,12 +81,21 @@ func (o *OidcClient) setupHandlers() error {
|
||||
redirectURI.Path = "/auth/callback"
|
||||
successURI := o.baseURI()
|
||||
successURI.Path = "/success"
|
||||
var CAKey []byte
|
||||
resp, err := sshrimp_http.Client.Get(o.Agent.CAUrls[0])
|
||||
if err == nil && resp.Header.Get("Content-Type") == "text/x-ssh-public-key" {
|
||||
CAKey, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
CAKey = []byte{}
|
||||
}
|
||||
}
|
||||
|
||||
cookieHandler := httphelper.NewCookieHandler(key, key, httphelper.WithUnsecure())
|
||||
cookieHandler := httphelper.NewCookieHandler(hashKey, nil)
|
||||
|
||||
options := []rp.Option{
|
||||
rp.WithCookieHandler(cookieHandler),
|
||||
rp.WithVerifierOpts(rp.WithIssuedAtOffset(5 * time.Second)),
|
||||
rp.WithHTTPClient(sshrimp_http.Client),
|
||||
}
|
||||
options = append(options, rp.WithPKCE(cookieHandler))
|
||||
if o.Agent.KeyPath != "" {
|
||||
@@ -110,19 +119,34 @@ func (o *OidcClient) setupHandlers() error {
|
||||
o.oidcMux.Handle("/login", rp.AuthURLHandler(state, provider))
|
||||
o.oidcMux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if o.Certificate != nil && o.Certificate.SignatureKey != nil {
|
||||
fmt.Fprintf(w, "The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
|
||||
Log.Printf("The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
|
||||
key := ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey)
|
||||
if len(CAKey) < 3 {
|
||||
CAKey = key
|
||||
}
|
||||
if !slices.Equal(key, CAKey) {
|
||||
Log.Errorf("Certificate Authority key has changed from %#v to %#v", string(CAKey), string(key))
|
||||
fmt.Fprintf(w, "\n\nCertificate Authority key has changed from \n%#v\nto \n%#v", string(CAKey), string(key))
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(w, "The SSH CA currently in use is:\n%s", CAKey)
|
||||
Log.Printf("The SSH CA currently in use is:\n%s", CAKey)
|
||||
}))
|
||||
o.oidcMux.Handle(successURI.Path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Return to the CLI.")
|
||||
if o.Certificate != nil && o.Certificate.SignatureKey != nil {
|
||||
fmt.Fprintf(w, "The SSH CA currently in use is: %s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
|
||||
Log.Printf("The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
|
||||
key := ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey)
|
||||
if len(CAKey) < 3 {
|
||||
CAKey = key
|
||||
}
|
||||
if !slices.Equal(key, CAKey) {
|
||||
Log.Errorf("Certificate Authority key has changed from %#v to %#v", string(CAKey), string(key))
|
||||
fmt.Fprintf(w, "\n\nCertificate Authority key has changed from \n%#v\nto \n%#v", string(CAKey), string(key))
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(w, "The SSH CA currently in use is: %s", CAKey)
|
||||
Log.Printf("The SSH CA currently in use is:\n%s", CAKey)
|
||||
}))
|
||||
|
||||
// for demonstration purposes the returned userinfo response is written as JSON object onto response
|
||||
marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
|
||||
o.OIDCToken <- tokens
|
||||
w.Header().Add("location", successURI.String())
|
||||
|
||||
@@ -4,11 +4,10 @@ import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"git.narnian.us/lordwelch/sshrimp/internal/signer"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/config"
|
||||
"gitea.narnian.us/lordwelch/sshrimp/internal/signer"
|
||||
"github.com/pkg/browser"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/zitadel/oidc/pkg/oidc"
|
||||
@@ -28,17 +27,16 @@ type sshrimpAgent struct {
|
||||
|
||||
// NewSSHrimpAgent returns an agent.Agent capable of signing certificates with a SSHrimp Certificate Authority
|
||||
func NewSSHrimpAgent(c *config.SSHrimp, signer ssh.Signer) (agent.Agent, error) {
|
||||
|
||||
oidcClient, err := newOIDCClient(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go func() {
|
||||
|
||||
for {
|
||||
if err = oidcClient.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
|
||||
Log.Logger.Errorf("Server failed: %v", err)
|
||||
os.Exit(99)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -105,7 +103,6 @@ func (r *sshrimpAgent) List() ([]*agent.Key, error) {
|
||||
Log.Traceln("Certificate has expired")
|
||||
Log.Traceln("authenticating token")
|
||||
err := r.authenticate()
|
||||
|
||||
if err != nil {
|
||||
Log.Errorf("authenticating the token failed: %v", err)
|
||||
return nil, err
|
||||
@@ -172,6 +169,7 @@ func (r *sshrimpAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent
|
||||
Log.Traceln("signing data")
|
||||
return r.Sign(key, data)
|
||||
}
|
||||
|
||||
func (r *sshrimpAgent) Extension(extensionType string, contents []byte) ([]byte, error) {
|
||||
return nil, agent.ErrExtensionUnsupported
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user