1
0
Fork 0
mirror of https://github.com/portainer/portainer.git synced 2025-07-19 13:29:41 +02:00

fix(oauth): analyze id_token for Azure [EE-2984] (#7000)

This commit is contained in:
Dmitry Salakhov 2022-07-06 13:22:57 +12:00 committed by GitHub
parent 0cd2a4558b
commit fd4b515350
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 428 additions and 46 deletions

View file

@ -3,16 +3,18 @@ package oauth
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"mime"
"net/http"
"net/url"
"strings"
"golang.org/x/oauth2"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
log "github.com/sirupsen/logrus"
)
// Service represents a service used to authenticate users against an authorization server
@ -29,17 +31,39 @@ func NewService() *Service {
func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) {
token, err := getOAuthToken(code, configuration)
if err != nil {
log.Printf("[DEBUG] - Failed retrieving access token: %v", err)
log.Debugf("[internal,oauth] [message: failed retrieving oauth token: %v]", err)
return "", err
}
username, err := getUsername(token.AccessToken, configuration)
idToken, err := getIdToken(token)
if err != nil {
log.Printf("[DEBUG] - Failed retrieving oauth user name: %v", err)
log.Debugf("[internal,oauth] [message: failed parsing id_token: %v]", err)
}
resource, err := getResource(token.AccessToken, configuration)
if err != nil {
log.Debugf("[internal,oauth] [message: failed retrieving resource: %v]", err)
return "", err
}
resource = mergeSecondIntoFirst(idToken, resource)
username, err := getUsername(resource, configuration)
if err != nil {
log.Debugf("[internal,oauth] [message: failed retrieving username: %v]", err)
return "", err
}
return username, nil
}
// mergeSecondIntoFirst merges the overlap map into the base overwriting any existing values.
func mergeSecondIntoFirst(base map[string]interface{}, overlap map[string]interface{}) map[string]interface{} {
for k, v := range overlap {
base[k] = v
}
return base
}
func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) {
unescapedCode, err := url.QueryUnescape(code)
if err != nil {
@ -55,27 +79,55 @@ func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2
return token, nil
}
func getUsername(token string, configuration *portainer.OAuthSettings) (string, error) {
// getIdToken retrieves parsed id_token from the OAuth token response.
// This is necessary for OAuth providers like Azure
// that do not provide information about user groups on the user resource endpoint.
func getIdToken(token *oauth2.Token) (map[string]interface{}, error) {
tokenData := make(map[string]interface{})
idToken := token.Extra("id_token")
if idToken == nil {
return tokenData, nil
}
jwtParser := jwt.Parser{
SkipClaimsValidation: true,
}
t, _, err := jwtParser.ParseUnverified(idToken.(string), jwt.MapClaims{})
if err != nil {
return tokenData, errors.Wrap(err, "failed to parse id_token")
}
if claims, ok := t.Claims.(jwt.MapClaims); ok {
for k, v := range claims {
tokenData[k] = v
}
}
return tokenData, nil
}
func getResource(token string, configuration *portainer.OAuthSettings) (map[string]interface{}, error) {
req, err := http.NewRequest("GET", configuration.ResourceURI, nil)
if err != nil {
return "", err
return nil, err
}
client := &http.Client{}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := client.Do(req)
if err != nil {
return "", err
return nil, err
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
return nil, err
}
if resp.StatusCode != http.StatusOK {
return "", &oauth2.RetrieveError{
return nil, &oauth2.RetrieveError{
Response: resp,
Body: body,
}
@ -83,47 +135,32 @@ func getUsername(token string, configuration *portainer.OAuthSettings) (string,
content, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
return "", err
return nil, err
}
if content == "application/x-www-form-urlencoded" || content == "text/plain" {
values, err := url.ParseQuery(string(body))
if err != nil {
return "", err
return nil, err
}
username := values.Get(configuration.UserIdentifier)
if username == "" {
return username, &oauth2.RetrieveError{
Response: resp,
Body: body,
datamap := make(map[string]interface{})
for k, v := range values {
if len(v) == 0 {
datamap[k] = ""
} else {
datamap[k] = v[0]
}
}
return username, nil
return datamap, nil
}
var datamap map[string]interface{}
if err = json.Unmarshal(body, &datamap); err != nil {
return "", err
return nil, err
}
username, ok := datamap[configuration.UserIdentifier].(string)
if ok && username != "" {
return username, nil
}
if !ok {
username, ok := datamap[configuration.UserIdentifier].(float64)
if ok && username != 0 {
return fmt.Sprint(int(username)), nil
}
}
return "", &oauth2.RetrieveError{
Response: resp,
Body: body,
}
return datamap, nil
}
func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config {
@ -137,6 +174,6 @@ func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config {
ClientSecret: configuration.ClientSecret,
Endpoint: endpoint,
RedirectURL: configuration.RedirectURI,
Scopes: []string{configuration.Scopes},
Scopes: strings.Split(configuration.Scopes, ","),
}
}