sshrimp/internal/identity/identity.go
Timmy Welch a9a40622ca
Add compatibility with Zitadel
Expand ~ to HOME in Agent.Socket
Add url override for gcloud functions v2
Add logging for parsing the principals
go fmt
2023-01-14 10:49:39 -08:00

176 lines
4.4 KiB
Go

package identity
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"regexp"
"strings"
"unicode/utf8"
"git.narnian.us/lordwelch/sshrimp/internal/config"
"github.com/coreos/go-oidc/v3/oidc"
)
func init() {
	// Every line this package logs is meant to be a standalone JSON object
	// (see Entry.String) so that Cloud Logging can parse it. Any prefix the
	// standard logger adds, including its default date/time stamp, would
	// break that parsing, so all prefix flags are cleared. Cloud Logging
	// attaches its own timestamp when the logs are shipped.
	log.SetFlags(0)
}
// Entry is one structured log record. Its String method renders it as a
// single JSON object in the field layout Cloud Logging recognizes.
type Entry struct {
	// Message is the human-readable text of the record.
	Message string `json:"message"`
	// Severity is the log level; String substitutes "INFO" when it is empty.
	Severity string `json:"severity,omitempty"`
	// Trace is emitted under Cloud Logging's special
	// "logging.googleapis.com/trace" key.
	Trace string `json:"logging.googleapis.com/trace,omitempty"`
	// Component shows up in Logs Explorer as `jsonPayload.component`, where
	// it can be filtered on and displayed.
	Component string `json:"component,omitempty"`
}

// String implements fmt.Stringer by encoding e as JSON. An empty Severity
// is rendered as "INFO". If encoding fails, the error is reported through
// the standard logger and the empty string is returned.
func (e Entry) String() string {
	rendered := e // value receiver: defaulting Severity never touches the caller's copy
	if len(rendered.Severity) == 0 {
		rendered.Severity = "INFO"
	}
	encoded, marshalErr := json.Marshal(rendered)
	if marshalErr != nil {
		log.Printf("json.Marshal: %v", marshalErr)
	}
	return string(encoded)
}
// Identity checks OIDC ID tokens with a configured verifier and extracts
// the usernames carried in their claims (see Validate).
type Identity struct {
	// ctx is the context handed to verifier.Verify; NewIdentity sets it to
	// context.Background().
	ctx context.Context
	// verifier validates raw ID token strings in Validate.
	verifier *oidc.IDTokenVerifier
	// usernameREs[i], when present, is applied to every value read from
	// usernameClaims[i]; its first capture group becomes the username.
	usernameREs []*regexp.Regexp
	// usernameClaims holds dot-separated claim paths to read usernames from.
	usernameClaims []string
}
// NewIdentity returns an Identity for the OIDC provider at
// c.Agent.ProviderURL. Its verifier is configured for client ID
// c.Agent.ClientID and accepts only RS256-signed tokens.
//
// c.CertificateAuthority.UsernameRegexs are compiled here. Each one must
// contain at least one capture group, because the first group is what
// becomes the username (see parseUsername). An invalid regex, or one with
// no capture group, is returned as an error instead of panicking. Regexes
// are checked before the provider is contacted, so configuration mistakes
// surface without any network round trip.
func NewIdentity(c *config.SSHrimp) (*Identity, error) {
	regexes := make([]*regexp.Regexp, 0, len(c.CertificateAuthority.UsernameRegexs))
	for idx, expr := range c.CertificateAuthority.UsernameRegexs {
		re, err := regexp.Compile(expr)
		if err != nil {
			return nil, fmt.Errorf("compiling username regex %d (%q): %w", idx, expr, err)
		}
		if re.NumSubexp() < 1 {
			// parseUsername takes the username from capture group 1.
			return nil, fmt.Errorf("username regex %d (%q) has no capture group", idx, expr)
		}
		regexes = append(regexes, re)
	}

	ctx := context.Background()
	provider, err := oidc.NewProvider(ctx, c.Agent.ProviderURL)
	if err != nil {
		return nil, err
	}
	oidcConfig := &oidc.Config{
		ClientID:             c.Agent.ClientID,
		SupportedSigningAlgs: []string{"RS256"},
	}

	return &Identity{
		ctx:            ctx,
		verifier:       provider.Verifier(oidcConfig),
		usernameREs:    regexes,
		usernameClaims: c.CertificateAuthority.UsernameClaims,
	}, nil
}
// Validate verifies the raw ID token string with the configured verifier
// and returns the usernames extracted from its claims (see getUsernames).
//
// A verification failure is wrapped with %w, so callers can inspect the
// underlying oidc error with errors.Is / errors.As. The message text is the
// same as before: "failed to verify identity token: <cause>".
func (i *Identity) Validate(token string) ([]string, error) {
	idToken, err := i.verifier.Verify(i.ctx, token)
	if err != nil {
		return nil, fmt.Errorf("failed to verify identity token: %w", err)
	}
	return i.getUsernames(idToken)
}
// getUsernames decodes the verified token's claims and collects usernames
// from every configured claim path.
//
// For claim path usernameClaims[idx], each value returned by getClaim is
// used as-is. If usernameREs[idx] exists, the value is first reduced to that
// regex's first capture group via parseUsername.
//
// Empty results are discarded. These come from a value the regex did not
// match, or from an empty claim value. Dropping them means a non-matching
// value can neither become an empty-string principal nor hide the "claim
// not in identity token" error below.
func (i *Identity) getUsernames(idToken *oidc.IDToken) ([]string, error) {
	var claims map[string]interface{}
	if err := idToken.Claims(&claims); err != nil {
		return nil, fmt.Errorf("failed to parse claims: %w", err)
	}

	usernames := make([]string, 0, len(i.usernameClaims))
	for idx, claim := range i.usernameClaims {
		var re *regexp.Regexp
		if idx < len(i.usernameREs) {
			re = i.usernameREs[idx]
		}
		for _, name := range getClaim(claim, claims) {
			if re != nil {
				name = parseUsername(name, re)
			}
			if name != "" {
				usernames = append(usernames, name)
			}
		}
	}

	if len(usernames) == 0 {
		return nil, errors.New("configured username claim not in identity token")
	}
	return usernames, nil
}
// parseUsername returns the first capture group of re's leftmost match in
// username. It returns "" when re does not match, or when re has no capture
// group to take the username from.
func parseUsername(username string, re *regexp.Regexp) string {
	match := re.FindStringSubmatch(username)
	if len(match) < 2 {
		// match is nil when there is no match. It has length 1 when re
		// contains no capture group; indexing match[1] would then panic.
		return ""
	}
	return match[1]
}
// getClaim resolves the dot-separated claim path claim against the decoded
// ID token claims. It returns every string value found at the end of that
// path, each passed through base64Decode.
//
// Examples: "email", or "urn:zitadel:iam:user:metadata.username". Path
// segments index into nested JSON objects. When an intermediate segment
// names a JSON array, the rest of the path is resolved against every object
// in that array, and all results are collected. Only string leaf values are
// returned; a missing segment or a value of any other type contributes
// nothing.
//
// Claim values are deliberately not logged here: the full claim set may
// contain personal data.
func getClaim(claim string, claims map[string]interface{}) []string {
	return base64Decode(lookupClaimStrings(strings.Split(claim, "."), claims))
}

// lookupClaimStrings walks path through claims and returns the raw
// (undecoded) string values found at its end.
func lookupClaimStrings(path []string, claims map[string]interface{}) []string {
	if len(path) == 0 {
		return nil
	}
	value, ok := claims[path[0]]
	if !ok {
		return nil
	}
	if len(path) == 1 {
		if name, ok := value.(string); ok {
			return []string{name}
		}
		return nil
	}
	switch v := value.(type) {
	case map[string]interface{}:
		return lookupClaimStrings(path[1:], v)
	case []interface{}:
		// encoding/json decodes JSON arrays into []interface{} (never into
		// []map[string]string), with each object as map[string]interface{}.
		var names []string
		for _, item := range v {
			if obj, ok := item.(map[string]interface{}); ok {
				names = append(names, lookupClaimStrings(path[1:], obj)...)
			}
		}
		return names
	default:
		return nil
	}
}
// base64Decode tries to decode each entry of names as unpadded, URL-safe
// base64 (base64.RawURLEncoding). An entry is replaced in place by its
// decoded form only when decoding succeeds and the result is valid UTF-8;
// all other entries are left unchanged. The same slice is returned.
//
// Every attempt, and every successful decode, is logged at NOTICE severity.
//
// NOTE(review): decoding is attempted on every configured claim value, not
// only on claims known to be base64-encoded. A plain value that happens to be
// valid base64 of UTF-8 text is silently rewritten. For example, "cm9vdA"
// becomes "root", so a user who controls such a claim may be able to obtain
// a different principal. Consider restricting decoding to the claims that
// need it.
func base64Decode(names []string) []string {
	for idx, name := range names {
		// log.Println already terminates the line. A "\n" inside Message
		// would end up as a literal escape in the JSON payload.
		log.Println(Entry{
			Severity: "NOTICE",
			Message:  fmt.Sprintf("Attempting to decode %q as base64", name),
		})
		decoded, err := base64.RawURLEncoding.DecodeString(name)
		if err != nil || !utf8.Valid(decoded) {
			continue
		}
		names[idx] = string(decoded)
		log.Println(Entry{
			Severity: "NOTICE",
			Message:  fmt.Sprintf("Successfully decoded %q from base64 as %q", name, names[idx]),
		})
	}
	return names
}