133 lines
3.2 KiB
Go
133 lines
3.2 KiB
Go
package identity
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
"unicode/utf8"
|
|
|
|
"git.narnian.us/lordwelch/sshrimp/internal/config"
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
)
|
|
|
|
// Identity holds information required to verify an OIDC identity token
|
|
type Identity struct {
|
|
ctx context.Context
|
|
verifier *oidc.IDTokenVerifier
|
|
usernameREs []*regexp.Regexp
|
|
usernameClaims []string
|
|
}
|
|
|
|
// NewIdentity return a new Identity, with default values and oidc proivder information populated
|
|
func NewIdentity(c *config.SSHrimp) (*Identity, error) {
|
|
ctx := context.Background()
|
|
provider, err := oidc.NewProvider(ctx, c.Agent.ProviderURL)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
oidcConfig := &oidc.Config{
|
|
ClientID: c.Agent.ClientID,
|
|
SupportedSigningAlgs: []string{"RS256"},
|
|
}
|
|
|
|
regexes := make([]*regexp.Regexp, 0, len(c.CertificateAuthority.UsernameRegexs))
|
|
for _, regex := range c.CertificateAuthority.UsernameRegexs {
|
|
regexes = append(regexes, regexp.MustCompile(regex))
|
|
}
|
|
|
|
return &Identity{
|
|
ctx: ctx,
|
|
verifier: provider.Verifier(oidcConfig),
|
|
usernameREs: regexes,
|
|
usernameClaims: c.CertificateAuthority.UsernameClaims,
|
|
}, nil
|
|
}
|
|
|
|
// Validate an identity token
|
|
func (i *Identity) Validate(token string) ([]string, error) {
|
|
|
|
idToken, err := i.verifier.Verify(i.ctx, token)
|
|
if err != nil {
|
|
return nil, errors.New("failed to verify identity token: " + err.Error())
|
|
}
|
|
return i.getUsernames(idToken)
|
|
}
|
|
|
|
func (i *Identity) getUsernames(idToken *oidc.IDToken) ([]string, error) {
|
|
var claims map[string]interface{}
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
return nil, errors.New("failed to parse claims: " + err.Error())
|
|
}
|
|
usernames := make([]string, 0, len(i.usernameClaims))
|
|
for idx, claim := range i.usernameClaims {
|
|
|
|
claimedUsernames := getClaim(claim, claims)
|
|
|
|
if idx < len(i.usernameREs) {
|
|
for _, name := range claimedUsernames {
|
|
usernames = append(usernames, parseUsername(name, i.usernameREs[idx]))
|
|
}
|
|
} else {
|
|
usernames = append(usernames, claimedUsernames...)
|
|
}
|
|
}
|
|
if len(usernames) < 1 {
|
|
return nil, errors.New("configured username claim not in identity token")
|
|
}
|
|
return usernames, nil
|
|
}
|
|
|
|
func parseUsername(username string, re *regexp.Regexp) string {
|
|
if match := re.FindStringSubmatch(username); match != nil {
|
|
return match[1]
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func getClaim(claim string, claims map[string]interface{}) []string {
|
|
usernames := make([]string, 0, 2)
|
|
parts := strings.Split(claim, ".")
|
|
f:
|
|
for idx, part := range parts {
|
|
if idx == len(parts)-1 {
|
|
name, ok := claims[part].(string)
|
|
if ok {
|
|
usernames = append(usernames, name)
|
|
}
|
|
return usernames
|
|
}
|
|
|
|
fmt.Println(part)
|
|
switch v := claims[part].(type) {
|
|
case map[string]interface{}:
|
|
claims = v
|
|
case []map[string]string:
|
|
fmt.Println("fuck zitadel")
|
|
for _, claimItem := range v {
|
|
name, ok := claimItem[parts[idx+1]]
|
|
if ok {
|
|
usernames = append(usernames, name)
|
|
}
|
|
}
|
|
break f
|
|
default:
|
|
break f
|
|
}
|
|
|
|
}
|
|
return base64Decode(usernames)
|
|
}
|
|
func base64Decode(names []string) []string {
|
|
for idx, name := range names {
|
|
decoded, err := base64.StdEncoding.Strict().DecodeString(name)
|
|
if err == nil && utf8.Valid(decoded) {
|
|
names[idx] = string(decoded)
|
|
}
|
|
}
|
|
return names
|
|
}
|