mirror of
https://github.com/portainer/portainer.git
synced 2025-07-19 13:29:41 +02:00
fix(oauth): analyze id_token for Azure [EE-2984] (#7000)
This commit is contained in:
parent
0cd2a4558b
commit
fd4b515350
7 changed files with 428 additions and 46 deletions
|
@ -3,16 +3,18 @@ package oauth
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Service represents a service used to authenticate users against an authorization server
|
||||
|
@ -29,17 +31,39 @@ func NewService() *Service {
|
|||
func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) {
|
||||
token, err := getOAuthToken(code, configuration)
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG] - Failed retrieving access token: %v", err)
|
||||
log.Debugf("[internal,oauth] [message: failed retrieving oauth token: %v]", err)
|
||||
return "", err
|
||||
}
|
||||
username, err := getUsername(token.AccessToken, configuration)
|
||||
|
||||
idToken, err := getIdToken(token)
|
||||
if err != nil {
|
||||
log.Printf("[DEBUG] - Failed retrieving oauth user name: %v", err)
|
||||
log.Debugf("[internal,oauth] [message: failed parsing id_token: %v]", err)
|
||||
}
|
||||
|
||||
resource, err := getResource(token.AccessToken, configuration)
|
||||
if err != nil {
|
||||
log.Debugf("[internal,oauth] [message: failed retrieving resource: %v]", err)
|
||||
return "", err
|
||||
}
|
||||
|
||||
resource = mergeSecondIntoFirst(idToken, resource)
|
||||
|
||||
username, err := getUsername(resource, configuration)
|
||||
if err != nil {
|
||||
log.Debugf("[internal,oauth] [message: failed retrieving username: %v]", err)
|
||||
return "", err
|
||||
}
|
||||
return username, nil
|
||||
}
|
||||
|
||||
// mergeSecondIntoFirst merges the overlap map into the base overwriting any existing values.
|
||||
func mergeSecondIntoFirst(base map[string]interface{}, overlap map[string]interface{}) map[string]interface{} {
|
||||
for k, v := range overlap {
|
||||
base[k] = v
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) {
|
||||
unescapedCode, err := url.QueryUnescape(code)
|
||||
if err != nil {
|
||||
|
@ -55,27 +79,55 @@ func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2
|
|||
return token, nil
|
||||
}
|
||||
|
||||
func getUsername(token string, configuration *portainer.OAuthSettings) (string, error) {
|
||||
// getIdToken retrieves parsed id_token from the OAuth token response.
|
||||
// This is necessary for OAuth providers like Azure
|
||||
// that do not provide information about user groups on the user resource endpoint.
|
||||
func getIdToken(token *oauth2.Token) (map[string]interface{}, error) {
|
||||
tokenData := make(map[string]interface{})
|
||||
|
||||
idToken := token.Extra("id_token")
|
||||
if idToken == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
jwtParser := jwt.Parser{
|
||||
SkipClaimsValidation: true,
|
||||
}
|
||||
|
||||
t, _, err := jwtParser.ParseUnverified(idToken.(string), jwt.MapClaims{})
|
||||
if err != nil {
|
||||
return tokenData, errors.Wrap(err, "failed to parse id_token")
|
||||
}
|
||||
|
||||
if claims, ok := t.Claims.(jwt.MapClaims); ok {
|
||||
for k, v := range claims {
|
||||
tokenData[k] = v
|
||||
}
|
||||
}
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
func getResource(token string, configuration *portainer.OAuthSettings) (map[string]interface{}, error) {
|
||||
req, err := http.NewRequest("GET", configuration.ResourceURI, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", &oauth2.RetrieveError{
|
||||
return nil, &oauth2.RetrieveError{
|
||||
Response: resp,
|
||||
Body: body,
|
||||
}
|
||||
|
@ -83,47 +135,32 @@ func getUsername(token string, configuration *portainer.OAuthSettings) (string,
|
|||
|
||||
content, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if content == "application/x-www-form-urlencoded" || content == "text/plain" {
|
||||
values, err := url.ParseQuery(string(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
username := values.Get(configuration.UserIdentifier)
|
||||
if username == "" {
|
||||
return username, &oauth2.RetrieveError{
|
||||
Response: resp,
|
||||
Body: body,
|
||||
datamap := make(map[string]interface{})
|
||||
for k, v := range values {
|
||||
if len(v) == 0 {
|
||||
datamap[k] = ""
|
||||
} else {
|
||||
datamap[k] = v[0]
|
||||
}
|
||||
}
|
||||
|
||||
return username, nil
|
||||
return datamap, nil
|
||||
}
|
||||
|
||||
var datamap map[string]interface{}
|
||||
if err = json.Unmarshal(body, &datamap); err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
username, ok := datamap[configuration.UserIdentifier].(string)
|
||||
if ok && username != "" {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
if !ok {
|
||||
username, ok := datamap[configuration.UserIdentifier].(float64)
|
||||
if ok && username != 0 {
|
||||
return fmt.Sprint(int(username)), nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", &oauth2.RetrieveError{
|
||||
Response: resp,
|
||||
Body: body,
|
||||
}
|
||||
return datamap, nil
|
||||
}
|
||||
|
||||
func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config {
|
||||
|
@ -137,6 +174,6 @@ func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config {
|
|||
ClientSecret: configuration.ClientSecret,
|
||||
Endpoint: endpoint,
|
||||
RedirectURL: configuration.RedirectURI,
|
||||
Scopes: []string{configuration.Scopes},
|
||||
Scopes: strings.Split(configuration.Scopes, ","),
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue