Store the whole oauth2 token content in keychain

This commit is contained in:
adrienperonnet 2019-04-18 15:36:38 +12:00
parent b11fe5c66f
commit 3c2e58c93e
2 changed files with 57 additions and 26 deletions

View File

@ -4,13 +4,12 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/99designs/keyring" "github.com/99designs/keyring"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts"
"github.com/stoggi/aws-oidc/provider" "github.com/stoggi/aws-oidc/provider"
kingpin "gopkg.in/alecthomas/kingpin.v2" "gopkg.in/alecthomas/kingpin.v2"
) )
type ExecConfig struct { type ExecConfig struct {
@ -22,7 +21,7 @@ type ExecConfig struct {
PKCE bool PKCE bool
Nonce bool Nonce bool
ReAuth bool ReAuth bool
AgentCommant []string AgentCommand []string
} }
// json metadata for AWS credential process. Ref: https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes // json metadata for AWS credential process. Ref: https://docs.aws.amazon.com/cli/latest/topic/config-vars.html#sourcing-credentials-from-external-processes
@ -81,7 +80,7 @@ func ConfigureExec(app *kingpin.Application, config *GlobalConfig) {
cmd.Arg("agent", "The executable and arguments of the local browser to use"). cmd.Arg("agent", "The executable and arguments of the local browser to use").
Default("open", "{}"). Default("open", "{}").
StringsVar(&execConfig.AgentCommant) StringsVar(&execConfig.AgentCommand)
cmd.Action(func(c *kingpin.ParseContext) error { cmd.Action(func(c *kingpin.ParseContext) error {
ExecCommand(app, config, &execConfig) ExecCommand(app, config, &execConfig)
@ -98,34 +97,43 @@ func ExecCommand(app *kingpin.Application, config *GlobalConfig, execConfig *Exe
PKCE: execConfig.PKCE, PKCE: execConfig.PKCE,
Nonce: execConfig.Nonce, Nonce: execConfig.Nonce,
ReAuth: execConfig.ReAuth, ReAuth: execConfig.ReAuth,
AgentCommand: execConfig.AgentCommant, AgentCommand: execConfig.AgentCommand,
} }
item, err := (*config.Keyring).Get(fmt.Sprintf("jwt-%s", execConfig.ClientID)) item, err := (*config.Keyring).Get(execConfig.ClientID)
if err != keyring.ErrKeyNotFound { if err != keyring.ErrKeyNotFound {
jwt := string(item.Data) oauth2Token := provider.Oauth2Token{}
accessKeyJson, err := assumeRoleWithWebIdentity(execConfig, jwt) err := json.Unmarshal(item.Data, &oauth2Token)
// Maybe fail silently in case oauth2 lib is not backward compatible
app.FatalIfError(err, "Error parsing Oauth2 token from token : %v", err)
accessKeyJson, err := assumeRoleWithWebIdentity(execConfig, &oauth2Token)
if err == nil { if err == nil {
fmt.Println(accessKeyJson) fmt.Println(accessKeyJson)
return return
} }
} }
authResult, err := provider.Authenticate(providerConfig) oauth2Token, err := provider.Authenticate(providerConfig)
app.FatalIfError(err, "Error authenticating to identity provider: %v", err)
accessKeyJson, err := assumeRoleWithWebIdentity(execConfig, oauth2Token)
app.FatalIfError(err, "Error assume role with web identity : %v", err)
json, err := json.Marshal(&oauth2Token)
app.FatalIfError(err, "Can't serialize Oauth2 token : %v", err)
accessKeyJson, err := assumeRoleWithWebIdentity(execConfig, authResult.JWT)
app.FatalIfError(err, "Error assume role with web identity", err)
(*config.Keyring).Set(keyring.Item{ (*config.Keyring).Set(keyring.Item{
Key: fmt.Sprintf("jwt-%s", execConfig.ClientID), Key: execConfig.ClientID,
Data: []byte(authResult.JWT), Data: json,
Label: fmt.Sprintf("JWT %s",execConfig.RoleArn), Label: fmt.Sprintf("Oauth2 token for %s",execConfig.RoleArn),
Description:"OIDC JWT", Description:"OIDC JWT",
}) })
fmt.Printf(accessKeyJson) fmt.Printf(accessKeyJson)
} }
func assumeRoleWithWebIdentity(execConfig *ExecConfig, jwt string) (string, error) {
func assumeRoleWithWebIdentity(execConfig *ExecConfig, oauth2Token *provider.Oauth2Token) (string, error) {
svc := sts.New(session.New()) svc := sts.New(session.New())
@ -133,7 +141,7 @@ func assumeRoleWithWebIdentity(execConfig *ExecConfig, jwt string) (string, erro
DurationSeconds: aws.Int64(execConfig.Duration), DurationSeconds: aws.Int64(execConfig.Duration),
RoleArn: aws.String(execConfig.RoleArn), RoleArn: aws.String(execConfig.RoleArn),
RoleSessionName: aws.String("aws-oidc"), RoleSessionName: aws.String("aws-oidc"),
WebIdentityToken: aws.String(jwt), WebIdentityToken: aws.String(oauth2Token.IDToken),
} }
assumeRoleResult, err := svc.AssumeRoleWithWebIdentity(input) assumeRoleResult, err := svc.AssumeRoleWithWebIdentity(input)

View File

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"os/exec" "os/exec"
"strings" "strings"
"time"
oidc "github.com/coreos/go-oidc" oidc "github.com/coreos/go-oidc"
@ -42,19 +43,31 @@ type TokenClaims struct {
Groups []string `json:"groups"` Groups []string `json:"groups"`
} }
func Authenticate(p *ProviderConfig) (Result, error) { type Oauth2Token struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type,omitempty"`
RefreshToken string `json:"refresh_token,omitempty"`
Expiry time.Time `json:"expiry,omitempty"`
IDToken string `json:"id_token,omitempty"`
}
func AuthenticateWithRefreshToken(p *ProviderConfig) {
}
func Authenticate(p *ProviderConfig) (*Oauth2Token, error) {
ctx := context.Background() ctx := context.Background()
resultChannel := make(chan Result, 0) resultChannel := make(chan *oauth2.Token, 0)
errorChannel := make(chan error, 0) errorChannel := make(chan error, 0)
provider, err := oidc.NewProvider(ctx, p.ProviderURL) provider, err := oidc.NewProvider(ctx, p.ProviderURL)
if err != nil { if err != nil {
return Result{"", nil, nil}, err return nil, err
} }
listener, err := net.Listen("tcp", "127.0.0.1:0") listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
return Result{"", nil, nil}, err return nil, err
} }
baseURL := "http://" + listener.Addr().String() baseURL := "http://" + listener.Addr().String()
redirectURL := baseURL + "/auth/callback" redirectURL := baseURL + "/auth/callback"
@ -75,13 +88,13 @@ func Authenticate(p *ProviderConfig) (Result, error) {
stateData := make([]byte, 32) stateData := make([]byte, 32)
if _, err = rand.Read(stateData); err != nil { if _, err = rand.Read(stateData); err != nil {
return Result{"", nil, nil}, err return nil, err
} }
state := base64.URLEncoding.EncodeToString(stateData) state := base64.URLEncoding.EncodeToString(stateData)
codeData := make([]byte, 32) codeData := make([]byte, 32)
if _, err = rand.Read(codeData); err != nil { if _, err = rand.Read(codeData); err != nil {
return Result{"", nil, nil}, err return nil, err
} }
codeVerifier := base64.StdEncoding.EncodeToString(codeData) codeVerifier := base64.StdEncoding.EncodeToString(codeData)
codeDigest := sha256.Sum256([]byte(codeVerifier)) codeDigest := sha256.Sum256([]byte(codeVerifier))
@ -156,7 +169,7 @@ func Authenticate(p *ProviderConfig) (Result, error) {
return return
} }
w.Write([]byte("Signed in successfully, return to cli app")) w.Write([]byte("Signed in successfully, return to cli app"))
resultChannel <- Result{rawIDToken, idToken, claims} resultChannel <- oauth2Token
}) })
// Filter the commands, and replace "{}" with our callback url // Filter the commands, and replace "{}" with our callback url
@ -179,9 +192,19 @@ func Authenticate(p *ProviderConfig) (Result, error) {
select { select {
case err := <-errorChannel: case err := <-errorChannel:
server.Shutdown(ctx) server.Shutdown(ctx)
return Result{}, err return nil, err
case res := <-resultChannel: case res := <-resultChannel:
server.Shutdown(ctx) server.Shutdown(ctx)
return res, nil idtoken, ok := res.Extra("id_token").(string)
if !ok {
return nil, errors.New("Can't extract id_token")
}
return &Oauth2Token{
AccessToken:res.AccessToken,
RefreshToken:res.RefreshToken,
Expiry:res.Expiry,
TokenType:res.TokenType,
IDToken:idtoken,
}, nil
} }
} }