diff --git a/mealie/core/exceptions.py b/mealie/core/exceptions.py index 874141ed9..89ed9538e 100644 --- a/mealie/core/exceptions.py +++ b/mealie/core/exceptions.py @@ -43,3 +43,6 @@ def mealie_registered_exceptions(t: Translator) -> dict: class UserLockedOut(Exception): ... + + +class MissingClaimException(Exception): ... diff --git a/mealie/core/security/providers/openid_provider.py b/mealie/core/security/providers/openid_provider.py index f487f9ef4..a1124d248 100644 --- a/mealie/core/security/providers/openid_provider.py +++ b/mealie/core/security/providers/openid_provider.py @@ -5,6 +5,7 @@ from sqlalchemy.orm.session import Session from mealie.core import root_logger from mealie.core.config import get_app_settings +from mealie.core.exceptions import MissingClaimException from mealie.core.security.providers.auth_provider import AuthProvider from mealie.db.models.users.users import AuthMethod from mealie.repos.all_repositories import get_repositories @@ -25,7 +26,7 @@ class OpenIDProvider(AuthProvider[UserInfo]): claims = self.data if not claims: self._logger.error("[OIDC] No claims in the id_token") - return None + raise MissingClaimException() # Log all claims for debugging self._logger.debug("[OIDC] Received claims:") @@ -38,13 +39,13 @@ class OpenIDProvider(AuthProvider[UserInfo]): self.required_claims, claims.keys(), ) - return None + raise MissingClaimException() # Check for empty required claims for claim in self.required_claims: if not claims.get(claim): self._logger.error("[OIDC] Required claim '%s' is empty", claim) - return None + raise MissingClaimException() repos = get_repositories(self.session, group_id=None, household_id=None) diff --git a/mealie/routes/auth/auth.py b/mealie/routes/auth/auth.py index bca56c351..2e5b66174 100644 --- a/mealie/routes/auth/auth.py +++ b/mealie/routes/auth/auth.py @@ -11,7 +11,7 @@ from starlette.datastructures import URLPath from mealie.core import root_logger, security from mealie.core.config import get_app_settings from mealie.core.dependencies import get_current_user -from mealie.core.exceptions import UserLockedOut +from mealie.core.exceptions import MissingClaimException, UserLockedOut from mealie.core.security.providers.openid_provider import OpenIDProvider from mealie.core.security.security import get_auth_provider from mealie.db.db_setup import generate_session @@ -125,14 +125,24 @@ async def oauth_callback(request: Request, response: Response, session: Session detail="Could not initialize OAuth client", ) client = oauth.create_client("oidc") + token = await client.authorize_access_token(request) - auth_provider = OpenIDProvider(session, token["userinfo"]) - auth = auth_provider.authenticate() + + auth = None + try: + auth_provider = OpenIDProvider(session, token["userinfo"]) + auth = auth_provider.authenticate() + except MissingClaimException: + try: + logger.debug("[OIDC] Claims not present in the ID token, pulling user info") + userinfo = await client.userinfo(token=token) + auth_provider = OpenIDProvider(session, userinfo) + auth = auth_provider.authenticate() + except MissingClaimException: + auth = None if not auth: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - ) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED) access_token, duration = auth expires_in = duration.total_seconds() if duration else None diff --git a/tests/unit_tests/core/security/providers/test_openid_provider.py b/tests/unit_tests/core/security/providers/test_openid_provider.py index 7973ed655..0fcd92690 100644 --- a/tests/unit_tests/core/security/providers/test_openid_provider.py +++ b/tests/unit_tests/core/security/providers/test_openid_provider.py @@ -1,8 +1,10 @@ -import pytest -from pytest import MonkeyPatch, Session import logging +import pytest +from pytest import MonkeyPatch, Session + from mealie.core.config import get_app_settings +from mealie.core.exceptions import MissingClaimException from mealie.core.security.providers.openid_provider import OpenIDProvider from mealie.repos.all_repositories import get_repositories from tests.utils.factories import random_email, random_string @@ -12,13 +14,15 @@ from tests.utils.fixture_schemas import TestUser def test_no_claims(): auth_provider = OpenIDProvider(None, None) - assert auth_provider.authenticate() is None + with pytest.raises(MissingClaimException): + auth_provider.authenticate() def test_empty_claims(): auth_provider = OpenIDProvider(None, {}) - assert auth_provider.authenticate() is None + with pytest.raises(MissingClaimException): + auth_provider.authenticate() def test_empty_required_claims(): @@ -30,14 +34,16 @@ def test_empty_required_claims(): } auth_provider = OpenIDProvider(None, data) - assert auth_provider.authenticate() is None + with pytest.raises(MissingClaimException): + auth_provider.authenticate() def test_missing_claims(): data = {"preferred_username": "dude1"} auth_provider = OpenIDProvider(None, data) - assert auth_provider.authenticate() is None + with pytest.raises(MissingClaimException): + auth_provider.authenticate() def test_missing_groups_claim(monkeypatch: MonkeyPatch): @@ -51,7 +57,8 @@ def test_missing_groups_claim(monkeypatch: MonkeyPatch): } auth_provider = OpenIDProvider(None, data) - assert auth_provider.authenticate() is None + with pytest.raises(MissingClaimException): + auth_provider.authenticate() def test_missing_user_group(monkeypatch: MonkeyPatch):