Add compatibility with Zitadel

Expand ~ to HOME in Agent.Socket
Add url override for gcloud functions v2
Add logging for parsing the principals
go fmt
This commit is contained in:
Timmy Welch 2023-01-14 10:49:39 -08:00
parent bcb5789044
commit a9a40622ca
No known key found for this signature in database
10 changed files with 94 additions and 30 deletions

View File

@ -28,7 +28,7 @@ import (
var (
sigExit = []os.Signal{os.Kill, os.Interrupt}
sigIgnore = []os.Signal{}
sigIgnore []os.Signal
logger = logrus.New()
log *logrus.Entry
appname = "sshrimp"
@ -111,6 +111,18 @@ func setupLoging(config cfg) error {
return nil
}
func ExpandPath(path string) string {
home, err := os.UserHomeDir()
if err != nil {
return path
}
if path[0] == '~' {
path = filepath.Join(home, path[1:])
}
return path
}
func main() {
var cli cfg
flag.StringVar(&cli.Config, "config", config.GetPath(), "sshrimp config file")
@ -140,14 +152,14 @@ func launchAgent(c *config.SSHrimp) error {
err error
listener net.Listener
privateKey crypto.Signer
signer ssh.Signer
sshSigner ssh.Signer
logMessage string
)
log.Traceln("Creating socket")
if _, err = os.Stat(c.Agent.Socket); err == nil {
if _, err = os.Stat(ExpandPath(c.Agent.Socket)); err == nil {
log.Tracef("File already exists at %s", c.Agent.Socket)
conn, sockErr := net.Dial("unix", c.Agent.Socket)
conn, sockErr := net.Dial("unix", ExpandPath(c.Agent.Socket))
if conn == nil {
logMessage = "conn is nil"
}
@ -168,7 +180,7 @@ func launchAgent(c *config.SSHrimp) error {
// This affects all files created for the process. Since this is a sensitive
// socket, only allow the current user to write to the socket.
syscall.Umask(0077)
listener, err = net.Listen("unix", c.Agent.Socket)
listener, err = net.Listen("unix", ExpandPath(c.Agent.Socket))
if err != nil {
return err
}
@ -182,15 +194,15 @@ func launchAgent(c *config.SSHrimp) error {
if err != nil {
return err
}
log.Traceln("Creating new signer from key")
signer, err = ssh.NewSignerFromKey(privateKey)
log.Traceln("Creating new sshSigner from key")
sshSigner, err = ssh.NewSignerFromKey(privateKey)
if err != nil {
return err
}
// Create the sshrimp agent with our configuration and the private key signer
log.Traceln("Creating new sshrimp agent from signer and config")
sshrimpAgent, err := sshrimpagent.NewSSHrimpAgent(c, signer)
// Create the sshrimp agent with our configuration and the private key sshSigner
log.Traceln("Creating new sshrimp agent from sshSigner and config")
sshrimpAgent, err := sshrimpagent.NewSSHrimpAgent(c, sshSigner)
if err != nil {
log.Logger.Errorf("Failed to create sshrimpAgent: %v", err)
}

View File

@ -35,7 +35,7 @@ func HandleRequest(ctx context.Context, event signer.SSHrimpEvent) (*signer.SSHr
// Setup our Certificate Authority signer backed by KMS
kmsSigner := signer.NewAWSSigner(c.CertificateAuthority.KeyAlias)
sshAlgorithmSigner, err := signer.NewAlgorithmSignerFromSigner(kmsSigner, ssh.SigAlgoRSASHA2256)
sshAlgorithmSigner, err := signer.NewAlgorithmSignerFromSigner(kmsSigner, ssh.KeyAlgoRSASHA256)
if err != nil {
return nil, err
}

View File

@ -44,7 +44,7 @@ func SSHrimp(w http.ResponseWriter, r *http.Request) {
// Setup our Certificate Authority signer backed by KMS
kmsSigner := signer.NewGCPSSigner(c.CertificateAuthority.KeyAlias)
sshAlgorithmSigner, err := signer.NewAlgorithmSignerFromSigner(kmsSigner, ssh.SigAlgoRSASHA2256)
sshAlgorithmSigner, err := signer.NewAlgorithmSignerFromSigner(kmsSigner, ssh.KeyAlgoRSASHA256)
if err != nil {
httpError(w, signer.SSHrimpResult{Certificate: "", ErrorMessage: err.Error(), ErrorType: http.StatusText(http.StatusBadRequest)}, http.StatusBadRequest)
return

View File

@ -26,6 +26,7 @@ type Agent struct {
Scopes []string
KeyPath string
Port int
Url string
}
// CertificateAuthority config for the sshrimp-ca lambda
@ -107,7 +108,7 @@ func NewSSHrimpWithDefaults() *SSHrimp {
sshrimp := SSHrimp{
Agent{
ProviderURL: "https://accounts.google.com",
Socket: "~/.ssh/sshrimp.toml",
Socket: "~/.ssh/sshrimp.sock",
Scopes: []string{"openid", "email", "profile"},
},
CertificateAuthority{

View File

@ -3,8 +3,10 @@ package identity
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"regexp"
"strings"
"unicode/utf8"
@ -13,6 +15,35 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
)
func init() {
// Disable log prefixes such as the default timestamp.
// Prefix text prevents the message from being parsed as JSON.
// A timestamp is added when shipping logs to Cloud Logging.
log.SetFlags(0)
}
// Entry defines a log entry.
type Entry struct {
Message string `json:"message"`
Severity string `json:"severity,omitempty"`
Trace string `json:"logging.googleapis.com/trace,omitempty"`
// Logs Explorer allows filtering and display of this as `jsonPayload.component`.
Component string `json:"component,omitempty"`
}
// String renders an entry structure to the JSON format expected by Cloud Logging.
func (e Entry) String() string {
if e.Severity == "" {
e.Severity = "INFO"
}
out, err := json.Marshal(e)
if err != nil {
log.Printf("json.Marshal: %v", err)
}
return string(out)
}
// Identity holds information required to verify an OIDC identity token
type Identity struct {
ctx context.Context
@ -98,15 +129,19 @@ f:
if ok {
usernames = append(usernames, name)
}
return usernames
return base64Decode(usernames)
}
fmt.Println(part)
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Fuck Off: %v", claims),
Component: part,
})
switch v := claims[part].(type) {
case map[string]interface{}:
claims = v
case []map[string]string:
fmt.Println("fuck zitadel")
for _, claimItem := range v {
name, ok := claimItem[parts[idx+1]]
if ok {
@ -123,9 +158,17 @@ f:
}
func base64Decode(names []string) []string {
for idx, name := range names {
decoded, err := base64.StdEncoding.Strict().DecodeString(name)
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Attempting to decode %q as base64\n", name),
})
decoded, err := base64.RawURLEncoding.DecodeString(name)
if err == nil && utf8.Valid(decoded) {
names[idx] = string(decoded)
log.Println(Entry{
Severity: "NOTICE",
Message: fmt.Sprintf("Successfully decoded %q as base64\n", names[idx]),
})
}
}
return names

View File

@ -78,7 +78,14 @@ func SignCertificateGCP(publicKey ssh.PublicKey, token string, forceCommand stri
return nil, err
}
result, err := http.Post(fmt.Sprintf("https://%s-%s.cloudfunctions.net/%s", region, c.CertificateAuthority.Project, c.CertificateAuthority.FunctionName), "application/json", bytes.NewReader(payload))
var uri string
if c.Agent.Url != "" {
uri = c.Agent.Url
} else {
uri = fmt.Sprintf("https://%s-%s.cloudfunctions.net/%s", region, c.CertificateAuthority.Project, c.CertificateAuthority.FunctionName)
}
result, err := http.Post(uri, "application/json", bytes.NewReader(payload))
if err != nil {
return nil, fmt.Errorf("http post failed: %w", err)
}
@ -114,10 +121,10 @@ func SignCertificateGCP(publicKey ssh.PublicKey, token string, forceCommand stri
// SignCertificateAWS given a public key, identity token and forceCommand, invoke the sshrimp-ca lambda function
func SignCertificateAWS(publicKey ssh.PublicKey, token string, forceCommand string, region string, c *config.SSHrimp) (*ssh.Certificate, error) {
// Create a lambdaService using the new temporary credentials for the role
session := session.Must(session.NewSession(&aws.Config{
sess := session.Must(session.NewSession(&aws.Config{
Region: aws.String(region),
}))
lambdaService := lambda.New(session)
lambdaService := lambda.New(sess)
// Setup the JSON payload for the SSHrimp CA
payload, err := json.Marshal(SSHrimpEvent{
@ -193,12 +200,12 @@ func ValidateRequest(event SSHrimpEvent, c *config.SSHrimp, requestID string, fu
}
// Generate a random nonce for the certificate
bytes := make([]byte, 32)
nonce := make([]byte, len(bytes)*2)
if _, err := rand.Read(bytes); err != nil {
nonceHex := make([]byte, 32)
nonce := make([]byte, len(nonceHex)*2)
if _, err := rand.Read(nonceHex); err != nil {
return ssh.Certificate{}, err
}
hex.Encode(nonce, bytes)
hex.Encode(nonce, nonceHex)
// Generate a random serial number
serial, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64))

View File

@ -95,7 +95,7 @@ func (o *OidcClient) setupHandlers() error {
provider, err := rp.NewRelyingPartyOIDC(o.Agent.ProviderURL, o.Agent.ClientID, o.Agent.ClientSecret, redirectURI.String(), o.Agent.Scopes, options...)
if err != nil {
return fmt.Errorf("Error creating provider: %w", err)
return fmt.Errorf("error creating provider: %w", err)
}
// generate some state (representing the state of the user in your application,

View File

@ -152,7 +152,7 @@ func (r *sshrimpAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent
if ok {
if flags&agent.SignatureFlagRsaSha512 == agent.SignatureFlagRsaSha512 {
Log.Traceln("sha 512 requested")
s, err := sign.SignWithAlgorithm(rand.Reader, data, ssh.SigAlgoRSASHA2512)
s, err := sign.SignWithAlgorithm(rand.Reader, data, ssh.KeyAlgoRSASHA512)
if err == nil {
Log.Debugln("sha 512 available")
return s, nil
@ -160,7 +160,7 @@ func (r *sshrimpAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent
}
if flags&agent.SignatureFlagRsaSha256 == agent.SignatureFlagRsaSha256 {
Log.Traceln("sha 256 requested")
s, err := sign.SignWithAlgorithm(rand.Reader, data, ssh.SigAlgoRSASHA2256)
s, err := sign.SignWithAlgorithm(rand.Reader, data, ssh.KeyAlgoRSASHA256)
if err == nil {
Log.Debugln("sha 256 available")
return s, nil

View File

@ -1,4 +1,5 @@
//+build mage
//go:build mage
// +build mage
package main

View File

@ -87,10 +87,10 @@ func PackageGCP() error {
defer zipFile.Close()
err = gcpCreateArchive(zipFile, []ZipFiles{
ZipFiles{Filename: "go.mod"},
{Filename: "go.mod"},
{"gcp/gcp.go", "gcp.go"},
ZipFiles{Filename: "internal"},
ZipFiles{config.GetPath(), filepath.Base(config.GetPath())},
{Filename: "internal"},
{config.GetPath(), filepath.Base(config.GetPath())},
}...)
if err != nil {