From 0276f52b49d1be9710884db6d1a40e110709947c Mon Sep 17 00:00:00 2001
From: Timmy Welch <timmy@narnian.us>
Date: Sat, 14 Jan 2023 18:30:28 -0800
Subject: [PATCH] Show current CA key on login

---
 internal/sshrimpagent/auth.go         | 33 +++++++++++++++++++--------
 internal/sshrimpagent/sshrimpagent.go |  4 +++-
 2 files changed, 26 insertions(+), 11 deletions(-)

diff --git a/internal/sshrimpagent/auth.go b/internal/sshrimpagent/auth.go
index 6e4b924..2755724 100644
--- a/internal/sshrimpagent/auth.go
+++ b/internal/sshrimpagent/auth.go
@@ -2,6 +2,7 @@ package sshrimpagent
 
 import (
 	"fmt"
+	"golang.org/x/crypto/ssh"
 	"net"
 	"net/http"
 	"net/url"
@@ -22,12 +23,13 @@ var (
 type OidcClient struct {
 	ListenAddress string
 	*http.Server
-	oidcMux   *http.ServeMux
-	OIDCToken chan *oidc.Tokens
+	oidcMux     *http.ServeMux
+	OIDCToken   chan *oidc.Tokens
+	Certificate *ssh.Certificate
 	*config.SSHrimp
 }
 
-func newOIDCClient(c *config.SSHrimp) (OidcClient, error) {
+func newOIDCClient(c *config.SSHrimp) (*OidcClient, error) {
 	if len(c.Agent.Scopes) < 1 {
 		c.Agent.Scopes = []string{"openid", "email", "profile"}
 	}
@@ -38,7 +40,7 @@ func newOIDCClient(c *config.SSHrimp) (OidcClient, error) {
 	token_chan := make(chan *oidc.Tokens)
 
 	oidcMux := http.NewServeMux()
-	return OidcClient{
+	return &OidcClient{
 		oidcMux: oidcMux,
 		Server: &http.Server{
 			Addr:              fmt.Sprintf("localhost:%d", c.Agent.Port),
@@ -48,8 +50,9 @@ func newOIDCClient(c *config.SSHrimp) (OidcClient, error) {
 			WriteTimeout:      time.Minute / 2,
 			IdleTimeout:       time.Minute / 2,
 		},
-		OIDCToken: token_chan,
-		SSHrimp:   c,
+		OIDCToken:   token_chan,
+		Certificate: &ssh.Certificate{},
+		SSHrimp:     c,
 	}, nil
 }
 
@@ -76,7 +79,7 @@ func (o *OidcClient) setupHandlers() error {
 	redirectURI := o.baseURI()
 	redirectURI.Path = "/auth/callback"
 	successURI := o.baseURI()
-	successURI.RawQuery = url.Values{"auth": []string{"success"}}.Encode()
+	successURI.Path = "/success"
 	// failURI := o.baseURI()
 	// failURI.RawQuery = url.Values{"auth":[]string{"fail"}}.Encode()
 
@@ -108,9 +111,19 @@ func (o *OidcClient) setupHandlers() error {
 	// the AuthURLHandler creates the auth request and redirects the user to the auth server
 	// including state handling with secure cookie and the possibility to use PKCE
 	o.oidcMux.Handle("/login", rp.AuthURLHandler(state, provider))
-	// o.oidcMux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-	// 	fmt.Fprintln(w, "Return to the CLI.")
-	// }))
+	o.oidcMux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if o.Certificate != nil && o.Certificate.SignatureKey != nil {
+			fmt.Fprintf(w, "The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
+			Log.Printf("The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
+		}
+	}))
+	o.oidcMux.Handle(successURI.Path, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		fmt.Fprintln(w, "Return to the CLI.")
+		if o.Certificate != nil && o.Certificate.SignatureKey != nil {
+			fmt.Fprintf(w, "The SSH CA currently in use is: %s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
+			Log.Printf("The SSH CA currently in use is:\n%s", ssh.MarshalAuthorizedKey(o.Certificate.SignatureKey))
+		}
+	}))
 
 	// for demonstration purposes the returned userinfo response is written as JSON object onto response
 	marshalUserinfo := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens, state string, rp rp.RelyingParty) {
diff --git a/internal/sshrimpagent/sshrimpagent.go b/internal/sshrimpagent/sshrimpagent.go
index e39ccd2..b8e5da4 100644
--- a/internal/sshrimpagent/sshrimpagent.go
+++ b/internal/sshrimpagent/sshrimpagent.go
@@ -19,7 +19,7 @@ import (
 var Log *logrus.Entry
 
 type sshrimpAgent struct {
-	oidcClient  OidcClient
+	oidcClient  *OidcClient
 	signer      ssh.Signer
 	certificate *ssh.Certificate
 	token       *oidc.Tokens
@@ -76,6 +76,7 @@ func (r *sshrimpAgent) authenticate() error {
 func (r *sshrimpAgent) RemoveAll() error {
 	Log.Debugln("Removing identity token and certificate")
 	r.certificate = &ssh.Certificate{}
+	r.oidcClient.Certificate = r.certificate
 	r.token = nil
 	return nil
 }
@@ -117,6 +118,7 @@ func (r *sshrimpAgent) List() ([]*agent.Key, error) {
 			return nil, err
 		}
 		r.certificate = cert
+		r.oidcClient.Certificate = r.certificate
 	}
 
 	var ids []*agent.Key