254 lines
6.4 KiB
Go

package provider
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"log"
"net"
"net/http"
"os/exec"
"strings"
"time"
"github.com/coreos/go-oidc"
"golang.org/x/net/context"
"golang.org/x/oauth2"
)
// ProviderConfig describes how to reach an OpenID Connect provider and how the
// interactive login performed by Authenticate should behave.
type ProviderConfig struct {
	// OAuth2 client credentials, and the provider URL handed to
	// oidc.NewProvider for discovery.
	ClientID, ClientSecret, ProviderURL string
	// PKCE adds an S256 code challenge to the authorization request and the
	// matching code_verifier to the code exchange; Nonce adds a nonce that is
	// compared with the one in the returned ID token.
	PKCE, Nonce bool
	// AgentCommand is started to begin the login. Every argument equal to "{}"
	// is replaced by the local callback server's base URL; if none is, the URL
	// is appended as the last argument.
	AgentCommand []string
}
// Result groups three representations of an ID token: the raw token string
// (JWT), the parsed *oidc.IDToken, and its decoded TokenClaims.
//
// NOTE(review): nothing in this file constructs a Result; presumably it is
// used elsewhere in the package or by callers — confirm before changing it.
type Result struct {
	// JWT presumably holds the raw, encoded ID token string.
	JWT    string
	Token  *oidc.IDToken
	Claims *TokenClaims
}
// TokenClaims is the subset of ID token claims that Authenticate decodes from
// a verified ID token; the struct tags give the JSON claim names.
//
// NOTE(review): OpenID Connect Core allows "aud" to be either a single string
// or an array of strings. Because Audience is declared as string, decoding a
// token whose aud is an array returns an error, and Authenticate then rejects
// the login. Supporting both forms needs a different field type, which is an
// API change — coordinate with callers before doing it.
type TokenClaims struct {
	Issuer        string   `json:"iss"`
	Audience      string   `json:"aud"`
	Subject       string   `json:"sub"`
	Picture       string   `json:"picture"`
	Email         string   `json:"email"`
	EmailVerified bool     `json:"email_verified"`
	Groups        []string `json:"groups"`
}
// OAuth2Token is a JSON-serializable copy of an oauth2.Token's core fields plus
// the OIDC ID token (the token response's "id_token" extra). refresh and
// Authenticate read from and write into values of this type.
//
// NOTE(review): encoding/json ignores omitempty on struct-typed fields such as
// time.Time, so a zero Expiry is still encoded (as "0001-01-01T00:00:00Z").
type OAuth2Token struct {
	AccessToken  string    `json:"access_token"`
	TokenType    string    `json:"token_type,omitempty"`
	RefreshToken string    `json:"refresh_token,omitempty"`
	Expiry       time.Time `json:"expiry,omitempty"`
	IDToken      string    `json:"id_token,omitempty"`
}
// refresh renews t in place through config's token source.
//
// The cached ID token is attached to the seed token as its "id_token" extra.
// If oauth2 still considers the access token valid, it hands the seed back
// unchanged, so the ID token is preserved. Otherwise the refresh-token grant
// runs and the provider's response must itself contain an id_token.
//
// On success every field of t is overwritten. On any error, t is left
// untouched and the error is returned.
func refresh(config oauth2.Config, t *OAuth2Token) error {
	seed := (&oauth2.Token{
		AccessToken:  t.AccessToken,
		TokenType:    t.TokenType,
		RefreshToken: t.RefreshToken,
		Expiry:       t.Expiry,
	}).WithExtra(map[string]interface{}{"id_token": t.IDToken})

	fresh, err := config.TokenSource(context.Background(), seed).Token()
	if err != nil {
		return err
	}

	rawID, isString := fresh.Extra("id_token").(string)
	if !isString {
		return errors.New("can't extract id_token")
	}

	*t = OAuth2Token{
		AccessToken:  fresh.AccessToken,
		TokenType:    fresh.TokenType,
		RefreshToken: fresh.RefreshToken,
		Expiry:       fresh.Expiry,
		IDToken:      rawID,
	}
	return nil
}
// Authenticate fills t with tokens for the provider described by p.
//
// When t already holds an access or refresh token, refresh is tried first. If
// it succeeds, Authenticate returns without user interaction. Otherwise an
// interactive authorization-code flow runs:
//
//  1. An HTTP server starts on a random 127.0.0.1 port.
//     - "/" redirects to the provider's authorization endpoint.
//     - "/auth/callback" checks the state, then exchanges the code (sending
//       the PKCE code_verifier when p.PKCE is set).
//     - It then verifies the ID token, checks the nonce when p.Nonce is set,
//       and decodes the claims into TokenClaims.
//  2. p.AgentCommand is started with the server's base URL substituted for
//     every "{}" argument, or appended when there is none.
//  3. The first callback outcome, success or failure, ends the flow. If no
//     callback arrives within two minutes, an error is returned.
//
// On success the tokens from the code exchange are written into t, so t must
// be non-nil.
func (p ProviderConfig) Authenticate(t *OAuth2Token) error {
	// The result is written through t; fail up front rather than panic after
	// the user has already completed the browser login.
	if t == nil {
		return errors.New("nil token destination")
	}
	ctx := context.Background()

	provider, err := oidc.NewProvider(ctx, p.ProviderURL)
	if err != nil {
		return err
	}
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	defer listener.Close()
	baseURL := "http://" + listener.Addr().String()

	verifier := provider.Verifier(&oidc.Config{
		ClientID:             p.ClientID,
		SupportedSigningAlgs: []string{"RS256"},
	})
	config := oauth2.Config{
		ClientID:     p.ClientID,
		ClientSecret: p.ClientSecret,
		Endpoint:     provider.Endpoint(),
		RedirectURL:  baseURL + "/auth/callback",
		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
	}

	// With neither an access nor a refresh token, refresh cannot succeed;
	// skip it so a first login does not log a meaningless failure.
	if t.AccessToken != "" || t.RefreshToken != "" {
		refreshErr := refresh(config, t)
		if refreshErr == nil {
			return nil
		}
		log.Println("token refresh failed, starting interactive login:", refreshErr)
	}

	state, err := randomURLSafeString(32)
	if err != nil {
		return err
	}
	// RFC 7636 restricts code_verifier to unreserved URL characters, which
	// unpadded base64url satisfies (43 chars for 32 bytes).
	codeVerifier, err := randomURLSafeString(32)
	if err != nil {
		return err
	}
	nonce, err := randomURLSafeString(32)
	if err != nil {
		return err
	}
	// S256 method: code_challenge = unpadded base64url(SHA-256(code_verifier)).
	codeDigest := sha256.Sum256([]byte(codeVerifier))
	codeChallenge := strings.TrimRight(base64.URLEncoding.EncodeToString(codeDigest[:]), "=")

	var authCodeOptions, tokenCodeOptions []oauth2.AuthCodeOption
	if p.PKCE {
		authCodeOptions = append(authCodeOptions,
			oauth2.SetAuthURLParam("code_challenge", codeChallenge),
			oauth2.SetAuthURLParam("code_challenge_method", "S256"),
		)
		tokenCodeOptions = append(tokenCodeOptions,
			oauth2.SetAuthURLParam("code_verifier", codeVerifier),
		)
	}
	if p.Nonce {
		authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("nonce", nonce))
	}

	// outcome is the single result reported by the callback handler.
	type outcome struct {
		token *oauth2.Token
		err   error
	}
	// The buffer plus non-blocking send keeps only the first outcome. A
	// repeated or late callback can then never block its handler, which
	// would otherwise stall server shutdown indefinitely.
	done := make(chan outcome, 1)
	deliver := func(o outcome) {
		select {
		case done <- o:
		default:
		}
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		http.Redirect(w, r, config.AuthCodeURL(state, authCodeOptions...), http.StatusFound)
	})
	mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()
		if query.Get("state") != state {
			http.Error(w, "state did not match", http.StatusBadRequest)
			deliver(outcome{err: errors.New("state did not match")})
			return
		}
		// A denied or failed authorization arrives as an "error" parameter
		// instead of a code (RFC 6749 §4.1.2.1); report it directly rather
		// than attempting an exchange with an empty code.
		if authErr := query.Get("error"); authErr != "" {
			msg := "authorization failed: " + authErr
			if desc := query.Get("error_description"); desc != "" {
				msg += ": " + desc
			}
			http.Error(w, msg, http.StatusBadRequest)
			deliver(outcome{err: errors.New(msg)})
			return
		}
		oauth2Token, err := config.Exchange(ctx, query.Get("code"), tokenCodeOptions...)
		if err != nil {
			http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError)
			deliver(outcome{err: errors.New("failed to exchange token: " + err.Error())})
			return
		}
		rawIDToken, ok := oauth2Token.Extra("id_token").(string)
		if !ok {
			http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError)
			deliver(outcome{err: errors.New("no id_token field in oauth2 token")})
			return
		}
		idToken, err := verifier.Verify(ctx, rawIDToken)
		if err != nil {
			http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError)
			deliver(outcome{err: errors.New("failed to verify ID Token: " + err.Error())})
			return
		}
		if p.Nonce && idToken.Nonce != nonce {
			http.Error(w, "Failed to verify Nonce", http.StatusInternalServerError)
			deliver(outcome{err: errors.New("failed to verify Nonce")})
			return
		}
		// Decoded only to reject tokens whose claims do not fit TokenClaims;
		// the decoded value is not used further here.
		var claims TokenClaims
		if err := idToken.Claims(&claims); err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			deliver(outcome{err: errors.New("failed to verify Claims: " + err.Error())})
			return
		}
		w.Write([]byte("Signed in successfully, return to cli app"))
		deliver(outcome{token: oauth2Token})
	})

	if len(p.AgentCommand) == 0 {
		return errors.New("no agent command configured")
	}
	args := agentCommandLine(p.AgentCommand, baseURL)

	// ReadHeaderTimeout bounds how long a client may take to send headers.
	server := &http.Server{
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	// Serve returns http.ErrServerClosed once Shutdown runs.
	go server.Serve(listener)
	defer func() {
		// Bounded so one stuck connection cannot hang Authenticate forever.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		server.Shutdown(shutdownCtx)
	}()

	// TODO: drop privileges before starting the agent.
	cmd := exec.Command(args[0], args[1:]...)
	if err := cmd.Start(); err != nil {
		return errors.New("failed to start agent command: " + err.Error())
	}
	// Fire and forget: the agent is never waited on, so release its process
	// resources immediately.
	cmd.Process.Release()

	select {
	case res := <-done:
		if res.err != nil {
			return res.err
		}
		rawIDToken, ok := res.token.Extra("id_token").(string)
		if !ok {
			return errors.New("can't extract id_token")
		}
		*t = OAuth2Token{
			AccessToken:  res.token.AccessToken,
			TokenType:    res.token.TokenType,
			RefreshToken: res.token.RefreshToken,
			Expiry:       res.token.Expiry,
			IDToken:      rawIDToken,
		}
		return nil
	case <-time.After(2 * time.Minute):
		return errors.New("no oauth2 flow callback received within last 2 minutes, exiting")
	}
}

// randomURLSafeString returns n cryptographically random bytes encoded as
// unpadded base64url, so the result contains only [A-Za-z0-9_-]. For n = 32 it
// is 43 characters long.
func randomURLSafeString(n int) (string, error) {
	buf := make([]byte, n)
	if _, err := rand.Read(buf); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// agentCommandLine returns a copy of template in which every argument equal to
// "{}" is replaced by url. If no argument was replaced, url is appended as the
// last argument so the agent always receives it.
func agentCommandLine(template []string, url string) []string {
	args := make([]string, 0, len(template)+1)
	replaced := false
	for _, arg := range template {
		if arg == "{}" {
			arg = url
			replaced = true
		}
		args = append(args, arg)
	}
	if !replaced {
		args = append(args, url)
	}
	return args
}