From 46e8f10aead6e12988219da488d3f6567888da21 Mon Sep 17 00:00:00 2001 From: Chaim Lev Ari Date: Fri, 18 Jan 2019 10:56:16 +0200 Subject: [PATCH] refactor(ouath): use oauth2 library to get token --- api/http/handler/auth/authenticate_oauth.go | 13 ++- api/oauth/oauth.go | 96 +++------------------ api/portainer.go | 2 +- 3 files changed, 18 insertions(+), 93 deletions(-) diff --git a/api/http/handler/auth/authenticate_oauth.go b/api/http/handler/auth/authenticate_oauth.go index 43dcdc33c..d8a446d3e 100644 --- a/api/http/handler/auth/authenticate_oauth.go +++ b/api/http/handler/auth/authenticate_oauth.go @@ -49,17 +49,17 @@ func (handler *Handler) validateOAuth(w http.ResponseWriter, r *http.Request) *h return &httperror.HandlerError{http.StatusForbidden, "Unable to acquire username", portainer.ErrUnauthorized} } - u, err := handler.UserService.UserByUsername(username) + user, err := handler.UserService.UserByUsername(username) if err != nil && err != portainer.ErrObjectNotFound { return &httperror.HandlerError{http.StatusInternalServerError, "Unable to retrieve a user with the specified username from the database", err} } - if u == nil && !settings.OAuthSettings.OAuthAutoCreateUsers { + if user == nil && !settings.OAuthSettings.OAuthAutoCreateUsers { return &httperror.HandlerError{http.StatusForbidden, "Unregistered account", portainer.ErrUnauthorized} } - if u == nil { - user := &portainer.User{ + if user == nil { + user = &portainer.User{ Username: username, Role: portainer.StandardUserRole, } @@ -69,10 +69,9 @@ func (handler *Handler) validateOAuth(w http.ResponseWriter, r *http.Request) *h return &httperror.HandlerError{http.StatusInternalServerError, "Unable to persist user inside the database", err} } - return handler.writeToken(w, user) } - return handler.writeToken(w, u) + return handler.writeToken(w, user) } func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *httperror.HandlerError { @@ -85,7 +84,7 @@ func (handler *Handler) loginOAuth(w http.ResponseWriter, r *http.Request) *http return &httperror.HandlerError{http.StatusForbidden, "OAuth authentication is disabled", err} } - url := handler.OAuthService.BuildLoginURL(settings.OAuthSettings) + url := handler.OAuthService.BuildLoginURL(&settings.OAuthSettings) http.Redirect(w, r, url, http.StatusTemporaryRedirect) return nil } diff --git a/api/oauth/oauth.go b/api/oauth/oauth.go index 87439a636..c716ef90a 100644 --- a/api/oauth/oauth.go +++ b/api/oauth/oauth.go @@ -1,12 +1,10 @@ package oauth import ( + "context" "encoding/json" - "errors" "fmt" - "io" "io/ioutil" - "log" "mime" "net/http" "net/url" @@ -26,84 +24,9 @@ type Service struct{} // GetAccessToken takes an access code and exchanges it for an access token from portainer OAuthSettings token endpoint func (*Service) GetAccessToken(code string, settings *portainer.OAuthSettings) (string, error) { - v := url.Values{} - v.Set("client_id", settings.ClientID) - v.Set("client_secret", settings.ClientSecret) - v.Set("redirect_uri", settings.RedirectURI) - v.Set("code", code) - v.Set("grant_type", "authorization_code") - - req, err := http.NewRequest("POST", settings.AccessTokenURI, strings.NewReader(v.Encode())) - if err != nil { - return "", err - } - - client := &http.Client{} - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - r, err := client.Do(req) - if err != nil { - return "", err - } - - body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20)) - if err != nil { - return "", fmt.Errorf("oauth2: cannot fetch token: %v", err) - } - - if r.StatusCode != http.StatusOK { - type ErrorMessage struct { - Message string - Type string - Code int - } - type ErrorResponse struct { - Error ErrorMessage - } - - var response ErrorResponse - if err = json.Unmarshal(body, &response); err != nil { - // report also error - log.Printf("[Error] - Failed parsing error body: %v", err) - return "", errors.New("oauth2: cannot fetch token") - } - - return "", errors.New(response.Error.Message) - } - - content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type")) - if content == "application/x-www-form-urlencoded" || content == "text/plain" { - values, err := url.ParseQuery(string(body)) - if err != nil { - return "", err - } - - token := values.Get("access_token") - log.Printf("[DEBUG] - returned body %v", values) - - if token == "" { - log.Printf("[DEBUG] - access token returned empty - %v", values) - return "", errors.New("oauth2: cannot fetch token") - } - - return token, nil - } - - type tokenJSON struct { - AccessToken string `json:"access_token"` - } - - var tj tokenJSON - if err = json.Unmarshal(body, &tj); err != nil { - return "", err - } - - token := tj.AccessToken - - if token == "" { - log.Printf("[DEBUG] - access token returned empty - %v with status code", string(body), r.StatusCode) - return "", errors.New("oauth2: cannot fetch token") - } - return token, nil + config := buildConfig(settings) + token, err := config.Exchange(context.Background(), code) + return token.AccessToken, err } // GetUsername takes a token and retrieves the portainer OAuthSettings user identifier from resource server. @@ -167,19 +90,22 @@ func (*Service) GetUsername(token string, settings *portainer.OAuthSettings) (st } // BuildLoginURL creates a login url for the oauth provider -func (*Service) BuildLoginURL(oauthSettings portainer.OAuthSettings) string { +func (*Service) BuildLoginURL(oauthSettings *portainer.OAuthSettings) string { + oauthConfig := buildConfig(oauthSettings) + return oauthConfig.AuthCodeURL("portainer") +} + +func buildConfig(oauthSettings *portainer.OAuthSettings) *oauth2.Config { endpoint := oauth2.Endpoint{ AuthURL: oauthSettings.AuthorizationURI, TokenURL: oauthSettings.AccessTokenURI, } - oauthConfig := &oauth2.Config{ + return &oauth2.Config{ ClientID: oauthSettings.ClientID, ClientSecret: oauthSettings.ClientSecret, Endpoint: endpoint, RedirectURL: oauthSettings.RedirectURI, Scopes: strings.Split(oauthSettings.Scopes, ","), } - - return oauthConfig.AuthCodeURL("portainer") } diff --git a/api/portainer.go b/api/portainer.go index b5682c344..63f886f0f 100644 --- a/api/portainer.go +++ b/api/portainer.go @@ -766,7 +766,7 @@ type ( OAuthService interface { GetAccessToken(code string, settings *OAuthSettings) (string, error) GetUsername(token string, settings *OAuthSettings) (string, error) - BuildLoginURL(oauthSettings OAuthSettings) string + BuildLoginURL(oauthSettings *OAuthSettings) string } // SwarmStackManager represents a service to manage Swarm stacks