From 966fca950b5115978c77d98d930f03f3f6a9547c Mon Sep 17 00:00:00 2001 From: andres-portainer <91705312+andres-portainer@users.noreply.github.com> Date: Mon, 28 Oct 2024 17:28:22 -0300 Subject: [PATCH] fix(oauth): add a timeout to getOAuthToken() BE-11283 (#63) --- api/oauth/oauth.go | 63 ++++++++++++++------------------ api/oauth/oauth_resource.go | 8 ++-- api/oauth/oauth_resource_test.go | 21 ++++------- api/oauth/oauth_test.go | 13 ++++--- 4 files changed, 44 insertions(+), 61 deletions(-) diff --git a/api/oauth/oauth.go b/api/oauth/oauth.go index 42f76dc12..ed3700b1e 100644 --- a/api/oauth/oauth.go +++ b/api/oauth/oauth.go @@ -3,10 +3,12 @@ package oauth import ( "context" "io" + "maps" "mime" "net/http" "net/url" "strings" + "time" portainer "github.com/portainer/portainer/api" @@ -29,28 +31,28 @@ func NewService() *Service { // On success, it will then return the username and token expiry time associated to authenticated user by fetching this information // from the resource server and matching it with the user identifier setting. func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings) (string, error) { - token, err := getOAuthToken(code, configuration) + token, err := GetOAuthToken(code, configuration) if err != nil { log.Error().Err(err).Msg("failed retrieving oauth token") return "", err } - idToken, err := getIdToken(token) + idToken, err := GetIdToken(token) if err != nil { log.Error().Err(err).Msg("failed parsing id_token") } - resource, err := getResource(token.AccessToken, configuration) + resource, err := GetResource(token.AccessToken, configuration.ResourceURI) if err != nil { log.Error().Err(err).Msg("failed retrieving resource") return "", err } - resource = mergeSecondIntoFirst(idToken, resource) + maps.Copy(idToken, resource) - username, err := getUsername(resource, configuration) + username, err := GetUsername(resource, configuration.UserIdentifier) if err != nil { log.Error().Err(err).Msg("failed retrieving username") @@ -60,34 +62,24 @@ func (*Service) Authenticate(code string, configuration *portainer.OAuthSettings return username, nil } -// mergeSecondIntoFirst merges the overlap map into the base overwriting any existing values. -func mergeSecondIntoFirst(base map[string]any, overlap map[string]any) map[string]any { - for k, v := range overlap { - base[k] = v - } - - return base -} - -func getOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) { +func GetOAuthToken(code string, configuration *portainer.OAuthSettings) (*oauth2.Token, error) { unescapedCode, err := url.QueryUnescape(code) if err != nil { return nil, err } config := buildConfig(configuration) - token, err := config.Exchange(context.Background(), unescapedCode) - if err != nil { - return nil, err - } - return token, nil + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + return config.Exchange(ctx, unescapedCode) } -// getIdToken retrieves parsed id_token from the OAuth token response. +// GetIdToken retrieves parsed id_token from the OAuth token response. // This is necessary for OAuth providers like Azure // that do not provide information about user groups on the user resource endpoint. -func getIdToken(token *oauth2.Token) (map[string]any, error) { +func GetIdToken(token *oauth2.Token) (map[string]any, error) { tokenData := make(map[string]any) idToken := token.Extra("id_token") @@ -113,8 +105,8 @@ func getIdToken(token *oauth2.Token) (map[string]any, error) { return tokenData, nil } -func getResource(token string, configuration *portainer.OAuthSettings) (map[string]any, error) { - req, err := http.NewRequest("GET", configuration.ResourceURI, nil) +func GetResource(token string, resourceURI string) (map[string]any, error) { + req, err := http.NewRequest(http.MethodGet, resourceURI, nil) if err != nil { return nil, err } @@ -159,6 +151,7 @@ func getResource(token string, configuration *portainer.OAuthSettings) (map[stri datamap[k] = v[0] } } + return datamap, nil } @@ -170,18 +163,16 @@ func getResource(token string, configuration *portainer.OAuthSettings) (map[stri return datamap, nil } -func buildConfig(configuration *portainer.OAuthSettings) *oauth2.Config { - endpoint := oauth2.Endpoint{ - AuthURL: configuration.AuthorizationURI, - TokenURL: configuration.AccessTokenURI, - AuthStyle: configuration.AuthStyle, - } - +func buildConfig(config *portainer.OAuthSettings) *oauth2.Config { return &oauth2.Config{ - ClientID: configuration.ClientID, - ClientSecret: configuration.ClientSecret, - Endpoint: endpoint, - RedirectURL: configuration.RedirectURI, - Scopes: strings.Split(configuration.Scopes, ","), + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURI, + Scopes: strings.Split(config.Scopes, ","), + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthorizationURI, + TokenURL: config.AccessTokenURI, + AuthStyle: config.AuthStyle, + }, } } diff --git a/api/oauth/oauth_resource.go b/api/oauth/oauth_resource.go index a850f1814..f0f1d7ad8 100644 --- a/api/oauth/oauth_resource.go +++ b/api/oauth/oauth_resource.go @@ -3,18 +3,16 @@ package oauth import ( "errors" "strconv" - - portainer "github.com/portainer/portainer/api" ) -func getUsername(datamap map[string]any, configuration *portainer.OAuthSettings) (string, error) { - username, ok := datamap[configuration.UserIdentifier].(string) +func GetUsername(datamap map[string]any, userIdentifier string) (string, error) { + username, ok := datamap[userIdentifier].(string) if ok && username != "" { return username, nil } if !ok { - username, ok := datamap[configuration.UserIdentifier].(float64) + username, ok := datamap[userIdentifier].(float64) if ok && username != 0 { return strconv.Itoa(int(username)), nil } diff --git a/api/oauth/oauth_resource_test.go b/api/oauth/oauth_resource_test.go index 58c7c2fbc..745bcaea0 100644 --- a/api/oauth/oauth_resource_test.go +++ b/api/oauth/oauth_resource_test.go @@ -11,8 +11,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"name": "john"} - _, err := getUsername(datamap, oauthSettings) - if err == nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil { t.Errorf("getUsername should fail if user identifier doesn't exist as key in oauth userinfo object") } }) @@ -21,8 +20,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"username": ""} - _, err := getUsername(datamap, oauthSettings) - if err == nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil { t.Errorf("getUsername should fail if username from oauth userinfo object is empty string") } }) @@ -31,8 +29,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"username": 0} - _, err := getUsername(datamap, oauthSettings) - if err == nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil { t.Errorf("getUsername should fail if username from oauth userinfo object is 0 val int") } }) @@ -41,8 +38,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"username": -1} - _, err := getUsername(datamap, oauthSettings) - if err == nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil { t.Errorf("getUsername should fail if username from oauth userinfo object is -1 (negative) int") } }) @@ -51,8 +47,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"username": "john"} - _, err := getUsername(datamap, oauthSettings) - if err != nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err != nil { t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-empty") } }) @@ -62,8 +57,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"username": 1} - _, err := getUsername(datamap, oauthSettings) - if err == nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err == nil { t.Errorf("getUsername should fail if username from oauth userinfo object matched is positive int") } }) @@ -72,8 +66,7 @@ func Test_getUsername(t *testing.T) { oauthSettings := &portainer.OAuthSettings{UserIdentifier: "username"} datamap := map[string]any{"username": 1.1} - _, err := getUsername(datamap, oauthSettings) - if err != nil { + if _, err := GetUsername(datamap, oauthSettings.UserIdentifier); err != nil { t.Errorf("getUsername should succeed if username from oauth userinfo object matched and non-zero (or negative)") } }) diff --git a/api/oauth/oauth_test.go b/api/oauth/oauth_test.go index bd70ce1be..6083ec773 100644 --- a/api/oauth/oauth_test.go +++ b/api/oauth/oauth_test.go @@ -5,6 +5,7 @@ import ( portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/oauth/oauthtest" + "github.com/stretchr/testify/assert" "golang.org/x/oauth2" ) @@ -16,14 +17,14 @@ func Test_getOAuthToken(t *testing.T) { t.Run("getOAuthToken fails upon invalid code", func(t *testing.T) { code := "" - if _, err := getOAuthToken(code, config); err == nil { + if _, err := GetOAuthToken(code, config); err == nil { t.Errorf("getOAuthToken should fail upon providing invalid code; code=%v", code) } }) t.Run("getOAuthToken succeeds upon providing valid code", func(t *testing.T) { code := validCode - token, err := getOAuthToken(code, config) + token, err := GetOAuthToken(code, config) if token == nil || err != nil { t.Errorf("getOAuthToken should successfully return access token upon providing valid code") @@ -78,7 +79,7 @@ func Test_getIdToken(t *testing.T) { token = token.WithExtra(map[string]any{"id_token": tc.idToken}) } - result, err := getIdToken(token) + result, err := GetIdToken(token) assert.Equal(t, err, tc.expectedError) assert.Equal(t, result, tc.expectedResult) }) @@ -90,19 +91,19 @@ func Test_getResource(t *testing.T) { defer srv.Close() t.Run("should fail upon missing Authorization Bearer header", func(t *testing.T) { - if _, err := getResource("", config); err == nil { + if _, err := GetResource("", config.ResourceURI); err == nil { t.Errorf("getResource should fail if access token is not provided in auth bearer header") } }) t.Run("should fail upon providing incorrect Authorization Bearer header", func(t *testing.T) { - if _, err := getResource("incorrect-token", config); err == nil { + if _, err := GetResource("incorrect-token", config.ResourceURI); err == nil { t.Errorf("getResource should fail if incorrect access token provided in auth bearer header") } }) t.Run("should succeed upon providing correct Authorization Bearer header", func(t *testing.T) { - if _, err := getResource(oauthtest.AccessToken, config); err != nil { + if _, err := GetResource(oauthtest.AccessToken, config.ResourceURI); err != nil { t.Errorf("getResource should succeed if correct access token provided in auth bearer header") } })