254 lines
6.4 KiB
Go
254 lines
6.4 KiB
Go
package provider
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"errors"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os/exec"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/coreos/go-oidc"
|
|
|
|
"golang.org/x/net/context"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// ProviderConfig holds the settings needed to sign in against an OpenID
// Connect provider through a local browser-based authorization code flow.
type ProviderConfig struct {
	// ClientID is the OAuth2 client identifier. It is also the client ID the
	// ID token verifier is configured with.
	ClientID string
	// ClientSecret is the OAuth2 client secret passed to the oauth2 config.
	ClientSecret string
	// ProviderURL is the issuer URL handed to OIDC provider discovery.
	ProviderURL string
	// PKCE enables sending a code challenge with the authorization request
	// and the matching code verifier with the token exchange.
	PKCE bool
	// Nonce enables sending a nonce with the authorization request and
	// checking it against the returned ID token.
	Nonce bool
	// AgentCommand is the program (first element) and arguments used to open
	// the local login URL. An argument equal to "{}" is replaced by that URL;
	// if no argument is "{}", the URL is appended as the last argument.
	AgentCommand []string
}
|
|
|
|
type Result struct {
|
|
JWT string
|
|
Token *oidc.IDToken
|
|
Claims *TokenClaims
|
|
}
|
|
|
|
// TokenClaims is the set of ID token claims this package decodes from the
// token's JSON payload.
type TokenClaims struct {
	// Issuer is the "iss" claim.
	Issuer string `json:"iss"`
	// Audience is the "aud" claim. It is declared as a single string, so a
	// token whose "aud" is a JSON array will fail to decode into this type.
	Audience string `json:"aud"`
	// Subject is the "sub" claim.
	Subject string `json:"sub"`
	// Picture is the "picture" claim.
	Picture string `json:"picture"`
	// Email is the "email" claim.
	Email string `json:"email"`
	// EmailVerified is the "email_verified" claim.
	EmailVerified bool `json:"email_verified"`
	// Groups is the "groups" claim.
	Groups []string `json:"groups"`
}
|
|
|
|
// OAuth2Token is the JSON-serialisable token set this package reads and
// updates: the OAuth2 token fields plus the raw OIDC ID token.
type OAuth2Token struct {
	// AccessToken is the OAuth2 access token.
	AccessToken string `json:"access_token"`
	// TokenType is the token type reported alongside the access token.
	TokenType string `json:"token_type,omitempty"`
	// RefreshToken is the token used to obtain new tokens without another
	// interactive login.
	RefreshToken string `json:"refresh_token,omitempty"`
	// Expiry is the expiry time copied from the oauth2 token.
	Expiry time.Time `json:"expiry,omitempty"`
	// IDToken is the raw ID token taken from the "id_token" extra field of
	// the oauth2 token.
	IDToken string `json:"id_token,omitempty"`
}
|
|
|
|
func refresh(config oauth2.Config, t *OAuth2Token) error {
|
|
ctx := context.Background()
|
|
|
|
tokenSourceToken := oauth2.Token{
|
|
AccessToken: t.AccessToken,
|
|
TokenType: t.TokenType,
|
|
RefreshToken: t.RefreshToken,
|
|
Expiry: t.Expiry,
|
|
}
|
|
ts := config.TokenSource(ctx, tokenSourceToken.WithExtra(map[string]interface{}{
|
|
"id_token": t.IDToken,
|
|
}))
|
|
|
|
res, err := ts.Token()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
idtoken, ok := res.Extra("id_token").(string)
|
|
if !ok {
|
|
return errors.New("can't extract id_token")
|
|
}
|
|
t.AccessToken = res.AccessToken
|
|
t.RefreshToken = res.RefreshToken
|
|
t.Expiry = res.Expiry
|
|
t.TokenType = res.TokenType
|
|
t.IDToken = idtoken
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p ProviderConfig) Authenticate(t *OAuth2Token) error {
|
|
ctx := context.Background()
|
|
resultChannel := make(chan *oauth2.Token)
|
|
errorChannel := make(chan error)
|
|
Mux := http.NewServeMux()
|
|
server := &http.Server{
|
|
Handler: Mux,
|
|
}
|
|
|
|
provider, err := oidc.NewProvider(ctx, p.ProviderURL)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer listener.Close()
|
|
baseURL := "http://" + listener.Addr().String()
|
|
redirectURL := baseURL + "/auth/callback"
|
|
|
|
oidcConfig := &oidc.Config{
|
|
ClientID: p.ClientID,
|
|
SupportedSigningAlgs: []string{"RS256"},
|
|
}
|
|
verifier := provider.Verifier(oidcConfig)
|
|
|
|
config := oauth2.Config{
|
|
ClientID: p.ClientID,
|
|
ClientSecret: p.ClientSecret,
|
|
Endpoint: provider.Endpoint(),
|
|
RedirectURL: redirectURL,
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
}
|
|
|
|
if t != nil {
|
|
if err := refresh(config, t); err == nil {
|
|
return nil
|
|
}
|
|
log.Println(err)
|
|
}
|
|
|
|
stateData := make([]byte, 32)
|
|
if _, err = rand.Read(stateData); err != nil {
|
|
return err
|
|
}
|
|
state := base64.URLEncoding.EncodeToString(stateData)
|
|
|
|
codeData := make([]byte, 32)
|
|
if _, err = rand.Read(codeData); err != nil {
|
|
return err
|
|
}
|
|
codeVerifier := base64.StdEncoding.EncodeToString(codeData)
|
|
codeDigest := sha256.Sum256([]byte(codeVerifier))
|
|
codeChallenge := base64.URLEncoding.EncodeToString(codeDigest[:])
|
|
codeChallengeEncoded := strings.Replace(codeChallenge, "=", "", -1)
|
|
|
|
nonceData := make([]byte, 32)
|
|
_, _ = rand.Read(nonceData)
|
|
nonce := base64.URLEncoding.EncodeToString(nonceData)
|
|
|
|
var authCodeOptions []oauth2.AuthCodeOption
|
|
var tokenCodeOptions []oauth2.AuthCodeOption
|
|
|
|
if p.PKCE {
|
|
authCodeOptions = append(authCodeOptions,
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallengeEncoded),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
tokenCodeOptions = append(tokenCodeOptions,
|
|
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
|
)
|
|
}
|
|
|
|
if p.Nonce {
|
|
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("nonce", nonce))
|
|
}
|
|
|
|
Mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
url := config.AuthCodeURL(state, authCodeOptions...)
|
|
http.Redirect(w, r, url, http.StatusFound)
|
|
})
|
|
|
|
Mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Query().Get("state") != state {
|
|
http.Error(w, "state did not match", http.StatusBadRequest)
|
|
errorChannel <- errors.New("state did not match")
|
|
return
|
|
}
|
|
|
|
oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code"), tokenCodeOptions...)
|
|
if err != nil {
|
|
http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
|
|
errorChannel <- errors.New("failed to exchange token: " + err.Error())
|
|
return
|
|
}
|
|
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
|
if !ok {
|
|
http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
|
|
errorChannel <- errors.New("no id_token field in oauth2 token")
|
|
return
|
|
}
|
|
idToken, err := verifier.Verify(ctx, rawIDToken)
|
|
if err != nil {
|
|
http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
|
|
errorChannel <- errors.New("failed to verify ID Token: " + err.Error())
|
|
return
|
|
}
|
|
if p.Nonce && idToken.Nonce != nonce {
|
|
http.Error(w, "Failed to verify Nonce", http.StatusInternalServerError)
|
|
errorChannel <- errors.New("failed to verify Nonce")
|
|
return
|
|
}
|
|
|
|
var claims = new(TokenClaims)
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
errorChannel <- errors.New("failed to verify Claims: " + err.Error())
|
|
return
|
|
}
|
|
w.Write([]byte("Signed in successfully, return to cli app"))
|
|
resultChannel <- oauth2Token
|
|
})
|
|
|
|
// Filter the commands, and replace "{}" with our callback url
|
|
c := make([]string, 0, len(p.AgentCommand))
|
|
replacedURL := false
|
|
for _, arg := range p.AgentCommand {
|
|
if arg == "{}" {
|
|
c = append(c, baseURL)
|
|
replacedURL = true
|
|
} else {
|
|
c = append(c, arg)
|
|
}
|
|
}
|
|
if !replacedURL {
|
|
c = append(c, baseURL)
|
|
}
|
|
|
|
//TODO Drop privileges
|
|
cmd := exec.Command(c[0], c[1:]...)
|
|
cmd.Start()
|
|
cmd.Process.Release()
|
|
|
|
go func() {
|
|
server.Serve(listener)
|
|
}()
|
|
|
|
select {
|
|
case err := <-errorChannel:
|
|
server.Shutdown(ctx)
|
|
return err
|
|
case res := <-resultChannel:
|
|
server.Shutdown(ctx)
|
|
IDToken, ok := res.Extra("id_token").(string)
|
|
if !ok {
|
|
return errors.New("can't extract id_token")
|
|
}
|
|
t.AccessToken = res.AccessToken
|
|
t.RefreshToken = res.RefreshToken
|
|
t.Expiry = res.Expiry
|
|
t.TokenType = res.TokenType
|
|
t.IDToken = IDToken
|
|
return nil
|
|
case <-time.After(2 * time.Minute):
|
|
server.Shutdown(ctx)
|
|
return errors.New("no oauth2 flow callback received within last 2 minutes, exiting")
|
|
}
|
|
}
|