1
0
Fork 0
mirror of https://github.com/mealie-recipes/mealie.git synced 2025-07-24 15:49:42 +02:00

fix: Lint Python code with ruff (#3799)

This commit is contained in:
Christian Clauss 2024-08-12 17:09:30 +02:00 committed by GitHub
parent 65ece35966
commit 432914e310
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 112 additions and 120 deletions

View file

@ -14,4 +14,5 @@ repos:
# Ruff version. # Ruff version.
rev: v0.5.7 rev: v0.5.7
hooks: hooks:
- id: ruff
- id: ruff-format - id: ruff-format

View file

@ -6,9 +6,6 @@ Create Date: 2023-02-10 21:18:32.405130
""" """
import sqlalchemy as sa
import mealie.db.migration_types
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View file

@ -11,7 +11,6 @@ from sqlalchemy import orm, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from text_unidecode import unidecode from text_unidecode import unidecode
import mealie.db.migration_types
from alembic import op from alembic import op
from mealie.db.models._model_utils.guid import GUID from mealie.db.models._model_utils.guid import GUID

View file

@ -8,7 +8,6 @@ Create Date: 2023-02-22 21:45:52.900964
import sqlalchemy as sa import sqlalchemy as sa
import mealie.db.migration_types
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View file

@ -6,12 +6,7 @@ Create Date: 2023-04-13 06:47:04.617131
""" """
import sqlalchemy as sa
import mealie.db.migration_types
from alembic import op from alembic import op
import alembic.context as context
from mealie.core.config import get_app_settings
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "b3dbb554ba53" revision = "b3dbb554ba53"

View file

@ -66,7 +66,7 @@ def populate_shopping_list_users():
user_id = find_user_id_for_group(group_id) user_id = find_user_id_for_group(group_id)
if user_id: if user_id:
session.execute( session.execute(
sa.text(f"UPDATE shopping_lists SET user_id=:user_id WHERE id=:id").bindparams( sa.text("UPDATE shopping_lists SET user_id=:user_id WHERE id=:id").bindparams(
user_id=user_id, id=list_id user_id=user_id, id=list_id
) )
) )
@ -74,7 +74,7 @@ def populate_shopping_list_users():
logger.warning( logger.warning(
f"No user found for shopping list {list_id} with group {group_id}; deleting shopping list" f"No user found for shopping list {list_id} with group {group_id}; deleting shopping list"
) )
session.execute(sa.text(f"DELETE FROM shopping_lists WHERE id=:id").bindparams(id=list_id)) session.execute(sa.text("DELETE FROM shopping_lists WHERE id=:id").bindparams(id=list_id))
def upgrade(): def upgrade():

View file

@ -6,9 +6,6 @@ Create Date: 2024-03-10 05:08:32.397027
""" """
import sqlalchemy as sa
import mealie.db.migration_types
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.

View file

@ -32,7 +32,7 @@ def new_user_rating(user_id: Any, recipe_id: Any, rating: float | None = None, i
if is_postgres(): if is_postgres():
id = str(uuid4()) id = str(uuid4())
else: else:
id = "%.32x" % uuid4().int id = "%.32x" % uuid4().int # noqa: UP031
now = datetime.now(timezone.utc).isoformat() now = datetime.now(timezone.utc).isoformat()
return { return {

View file

@ -4,7 +4,7 @@ from pathlib import Path
from fastapi import FastAPI from fastapi import FastAPI
from jinja2 import Template from jinja2 import Template
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from utils import PROJECT_DIR, CodeTemplates, HTTPRequest, RouteObject, RequestType from utils import PROJECT_DIR, CodeTemplates, HTTPRequest, RouteObject
CWD = Path(__file__).parent CWD = Path(__file__).parent

View file

@ -26,7 +26,7 @@ class SQLiteProvider(AbstractDBProvider, BaseModel):
@property @property
def db_url(self) -> str: def db_url(self) -> str:
return f"sqlite:///{str(self.db_path.absolute())}" return f"sqlite:///{self.db_path.absolute()!s}"
@property @property
def db_url_public(self) -> str: def db_url_public(self) -> str:

View file

@ -70,7 +70,7 @@ async def get_token(
@user_router.get("/refresh") @user_router.get("/refresh")
async def refresh_token(current_user: PrivateUser = Depends(get_current_user)): async def refresh_token(current_user: PrivateUser = Depends(get_current_user)):
"""Use a valid token to get another token""" """Use a valid token to get another token"""
access_token = security.create_access_token(data=dict(sub=str(current_user.id))) access_token = security.create_access_token(data={"sub": str(current_user.id)})
return MealieAuthToken.respond(access_token) return MealieAuthToken.respond(access_token)

View file

@ -32,7 +32,7 @@ def make_dependable(cls):
return cls(*args, **kwargs) return cls(*args, **kwargs)
except (ValidationError, RequestValidationError) as e: except (ValidationError, RequestValidationError) as e:
for error in e.errors(): for error in e.errors():
error["loc"] = ["query"] + list(error["loc"]) error["loc"] = ["query", *list(error["loc"])]
raise HTTPException(422, detail=[format_exception(ex) for ex in e.errors()]) from None raise HTTPException(422, detail=[format_exception(ex) for ex in e.errors()]) from None
init_cls_and_handle_errors.__signature__ = signature(cls) init_cls_and_handle_errors.__signature__ = signature(cls)

View file

@ -53,7 +53,7 @@ class GroupService(BaseService):
all_ids = self.repos.recipes.all_ids(target_id) all_ids = self.repos.recipes.all_ids(target_id)
used_size = sum( used_size = sum(
fs_stats.get_dir_size(f"{self.directories.RECIPE_DATA_DIR}/{str(recipe_id)}") for recipe_id in all_ids fs_stats.get_dir_size(f"{self.directories.RECIPE_DATA_DIR}/{recipe_id!s}") for recipe_id in all_ids
) )
return GroupStorage.bytes(used_size, ALLOWED_SIZE) return GroupStorage.bytes(used_size, ALLOWED_SIZE)

View file

@ -191,7 +191,7 @@ class ShoppingListService:
created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else [] created_items = self.list_items.create_many(filtered_create_items) if filtered_create_items else []
updated_items = self.list_items.update_many(update_items) if update_items else [] updated_items = self.list_items.update_many(update_items) if update_items else []
for list_id in set(item.shopping_list_id for item in created_items + updated_items): for list_id in {item.shopping_list_id for item in created_items + updated_items}:
self.remove_unused_recipe_references(list_id) self.remove_unused_recipe_references(list_id)
return ShoppingListItemsCollectionOut( return ShoppingListItemsCollectionOut(
@ -278,7 +278,7 @@ class ShoppingListService:
self.list_items.delete_many(delete_items) if delete_items else [], # type: ignore self.list_items.delete_many(delete_items) if delete_items else [], # type: ignore
) )
for list_id in set(item.shopping_list_id for item in updated_items + deleted_items): for list_id in {item.shopping_list_id for item in updated_items + deleted_items}:
self.remove_unused_recipe_references(list_id) self.remove_unused_recipe_references(list_id)
return ShoppingListItemsCollectionOut( return ShoppingListItemsCollectionOut(
@ -291,7 +291,7 @@ class ShoppingListService:
self.list_items.delete_many(set(delete_items)) if delete_items else [], # type: ignore self.list_items.delete_many(set(delete_items)) if delete_items else [], # type: ignore
) )
for list_id in set(item.shopping_list_id for item in deleted_items): for list_id in {item.shopping_list_id for item in deleted_items}:
self.remove_unused_recipe_references(list_id) self.remove_unused_recipe_references(list_id)
return ShoppingListItemsCollectionOut(created_items=[], updated_items=[], deleted_items=deleted_items) return ShoppingListItemsCollectionOut(created_items=[], updated_items=[], deleted_items=deleted_items)

View file

@ -122,7 +122,7 @@ def parse_ingredient(tokens) -> tuple[str, str]:
# no opening bracket anywhere -> just ignore the last bracket # no opening bracket anywhere -> just ignore the last bracket
ingredient, note = parse_ingredient_with_comma(tokens) ingredient, note = parse_ingredient_with_comma(tokens)
else: else:
# opening bracket found -> split in ingredient and note, remove brackets from note # noqa: E501 # opening bracket found -> split in ingredient and note, remove brackets from note
note = " ".join(tokens[start:])[1:-1] note = " ".join(tokens[start:])[1:-1]
ingredient = " ".join(tokens[:start]) ingredient = " ".join(tokens[:start])
else: else:

View file

@ -95,7 +95,7 @@ def insideParenthesis(token, tokens):
else: else:
line = " ".join(tokens) line = " ".join(tokens)
return ( return (
re.match(r".*\(.*" + re.escape(token) + r".*\).*", line) is not None # noqa: W605 - invalid dscape sequence re.match(r".*\(.*" + re.escape(token) + r".*\).*", line) is not None # - invalid dscape sequence
) )
@ -188,7 +188,7 @@ def import_data(lines):
# turn B-NAME/123 back into "name" # turn B-NAME/123 back into "name"
tag, confidence = re.split(r"/", columns[-1], maxsplit=1) tag, confidence = re.split(r"/", columns[-1], maxsplit=1)
tag = re.sub(r"^[BI]\-", "", tag).lower() # noqa: W605 - invalid dscape sequence tag = re.sub(r"^[BI]\-", "", tag).lower() # - invalid dscape sequence
# ==================== # ====================
# Confidence Getter # Confidence Getter
@ -261,6 +261,6 @@ def export_data(lines):
for i, token in enumerate(tokens): for i, token in enumerate(tokens):
features = getFeatures(token, i + 1, tokens) features = getFeatures(token, i + 1, tokens)
output.append(joinLine([token] + features)) output.append(joinLine([token, *features]))
output.append("") output.append("")
return "\n".join(output) return "\n".join(output)

View file

@ -136,7 +136,7 @@ class RecipeDataService(BaseService):
if ext not in img.IMAGE_EXTENSIONS: if ext not in img.IMAGE_EXTENSIONS:
ext = "jpg" # Guess the extension ext = "jpg" # Guess the extension
file_name = f"{str(self.recipe_id)}.{ext}" file_name = f"{self.recipe_id!s}.{ext}"
file_path = Recipe.directory_from_id(self.recipe_id).joinpath("images", file_name) file_path = Recipe.directory_from_id(self.recipe_id).joinpath("images", file_name)
async with AsyncClient(transport=AsyncSafeTransport()) as client: async with AsyncClient(transport=AsyncSafeTransport()) as client:

View file

@ -162,7 +162,7 @@ def clean_instructions(steps_object: list | dict | str, default: list | None = N
# } # }
# #
steps_object = typing.cast(dict, steps_object) steps_object = typing.cast(dict, steps_object)
return clean_instructions([x for x in steps_object.values()]) return clean_instructions(list(steps_object.values()))
case str(step_as_str): case str(step_as_str):
# Strings are weird, some sites return a single string with newlines # Strings are weird, some sites return a single string with newlines
# others returns a json string for some reasons # others returns a json string for some reasons
@ -481,7 +481,7 @@ def clean_tags(data: str | list[str]) -> list[str]:
case [str(), *_]: case [str(), *_]:
return [tag.strip().title() for tag in data if tag.strip()] return [tag.strip().title() for tag in data if tag.strip()]
case str(data): case str(data):
return clean_tags([t for t in data.split(",")]) return clean_tags(data.split(","))
case _: case _:
return [] return []
# should probably raise exception # should probably raise exception

View file

@ -55,7 +55,7 @@ async def create_from_url(url: str, translator: Translator) -> tuple[Recipe, Scr
new_recipe.image = "no image" new_recipe.image = "no image"
if new_recipe.name is None or new_recipe.name == "": if new_recipe.name is None or new_recipe.name == "":
new_recipe.name = f"No Recipe Name Found - {str(uuid4())}" new_recipe.name = f"No Recipe Name Found - {uuid4()!s}"
new_recipe.slug = slugify(new_recipe.name) new_recipe.slug = slugify(new_recipe.name)
return new_recipe, extras return new_recipe, extras

View file

@ -139,23 +139,29 @@ target-version = "py310"
# Enable Pyflakes `E` and `F` codes by default. # Enable Pyflakes `E` and `F` codes by default.
ignore = ["F403", "TID252", "B008"] ignore = ["F403", "TID252", "B008"]
select = [ select = [
"B", # flake8-bugbear
"C4", # McCabe complexity
"C90", # flake8-comprehensions
"DTZ", # flake8-datetimez
"E", # pycodestyles "E", # pycodestyles
"F", # pyflakes "F", # pyflakes
"I", # isort "I", # isort
"T", # flake8-print "T", # flake8-print
"UP", # pyupgrade "UP", # pyupgrade
"B", # flake8-bugbear
"DTZ", # flake8-datetimez
# "ANN", # flake8-annotations # "ANN", # flake8-annotations
# "C", # McCabe complexity
# "RUF", # Ruff specific
# "BLE", # blind-except # "BLE", # blind-except
# "RUF", # Ruff specific
] ]
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "E501"] "__init__.py" = ["E402", "E501"]
"alembic/versions/2022*" = ["E501"]
"alembic/versions/2023*" = ["E501"]
"dev/scripts/all_recipes_stress_test.py" = ["E501"]
"ldap_provider.py" = ["UP032"] "ldap_provider.py" = ["UP032"]
"tests/conftest.py" = ["E402"]
"tests/utils/routes/__init__.py" = ["F401"]
[tool.ruff.lint.mccabe] [tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10. # Unlike Flake8, default to a complexity level of 10.
max-complexity = 10 max-complexity = 24 # Default is 10.

View file

@ -1,5 +1,5 @@
import json import json
from typing import Generator from collections.abc import Generator
from pytest import fixture from pytest import fixture
from starlette.testclient import TestClient from starlette.testclient import TestClient

View file

@ -26,7 +26,7 @@ def test_public_about_get_app_info(api_client: TestClient, is_private_group: boo
assert as_dict["allowSignup"] == settings.ALLOW_SIGNUP assert as_dict["allowSignup"] == settings.ALLOW_SIGNUP
if is_private_group: if is_private_group:
assert as_dict["defaultGroupSlug"] == None assert as_dict["defaultGroupSlug"] is None
else: else:
assert as_dict["defaultGroupSlug"] == group.slug assert as_dict["defaultGroupSlug"] == group.slug

View file

@ -104,7 +104,7 @@ def test_bad_mealie_alpha_data_is_ignored(api_client: TestClient, unique_user: T
with open(invalid_json_path, "w"): with open(invalid_json_path, "w"):
pass # write nothing to the file, which is invalid JSON pass # write nothing to the file, which is invalid JSON
except Exception: except Exception:
raise Exception(os.listdir(tmpdir)) raise Exception(os.listdir(tmpdir)) # noqa: B904
modified_test_data = os.path.join(tmpdir, "modified-test-data.zip") modified_test_data = os.path.join(tmpdir, "modified-test-data.zip")
with ZipFile(modified_test_data, "w") as zf: with ZipFile(modified_test_data, "w") as zf:

View file

@ -49,7 +49,7 @@ def test_group_recipe_actions_get_all(api_client: TestClient, unique_user: TestU
response = api_client.get(api_routes.groups_recipe_actions, headers=unique_user.token) response = api_client.get(api_routes.groups_recipe_actions, headers=unique_user.token)
data = assert_deserialize(response, 200) data = assert_deserialize(response, 200)
fetched_ids = set(item["id"] for item in data["items"]) fetched_ids = {item["id"] for item in data["items"]}
for expected_id in expected_ids: for expected_id in expected_ids:
assert expected_id in fetched_ids assert expected_id in fetched_ids

View file

@ -1,5 +1,5 @@
from collections.abc import Generator
from pathlib import Path from pathlib import Path
from typing import Generator
import pytest import pytest
import sqlalchemy import sqlalchemy

View file

@ -4,8 +4,8 @@ import os
import random import random
import shutil import shutil
import tempfile import tempfile
from collections.abc import Generator
from pathlib import Path from pathlib import Path
from typing import Generator
from uuid import uuid4 from uuid import uuid4
from zipfile import ZipFile from zipfile import ZipFile
@ -489,9 +489,9 @@ def test_duplicate(api_client: TestClient, recipe_data: RecipeSiteTestCase, uniq
# Ingredients should have the same texts, but different ids # Ingredients should have the same texts, but different ids
assert duplicate_recipe["recipeIngredient"] != initial_recipe["recipeIngredient"] assert duplicate_recipe["recipeIngredient"] != initial_recipe["recipeIngredient"]
assert list(map(lambda i: i["note"], duplicate_recipe["recipeIngredient"])) == list( assert [i["note"] for i in duplicate_recipe["recipeIngredient"]] == [
map(lambda i: i["note"], initial_recipe["recipeIngredient"]) i["note"] for i in initial_recipe["recipeIngredient"]
) ]
previous_categories = initial_recipe["recipeCategory"] previous_categories = initial_recipe["recipeCategory"]
assert duplicate_recipe["recipeCategory"] == previous_categories assert duplicate_recipe["recipeCategory"] == previous_categories
@ -748,21 +748,21 @@ def test_get_recipes_organizer_filter(
# get recipes by organizer # get recipes by organizer
if organizer_type == "tags": if organizer_type == "tags":
organizer = random.choice(tags) organizer = random.choice(tags)
expected_recipe_ids = set( expected_recipe_ids = {
str(recipe.id) for recipe in recipes if organizer.id in [tag.id for tag in recipe.tags or []] str(recipe.id) for recipe in recipes if organizer.id in [tag.id for tag in recipe.tags or []]
) }
elif organizer_type == "categories": elif organizer_type == "categories":
organizer = random.choice(categories) organizer = random.choice(categories)
expected_recipe_ids = set( expected_recipe_ids = {
str(recipe.id) str(recipe.id)
for recipe in recipes for recipe in recipes
if organizer.id in [category.id for category in recipe.recipe_category or []] if organizer.id in [category.id for category in recipe.recipe_category or []]
) }
elif organizer_type == "tools": elif organizer_type == "tools":
organizer = random.choice(tools) organizer = random.choice(tools)
expected_recipe_ids = set( expected_recipe_ids = {
str(recipe.id) for recipe in recipes if organizer.id in [tool.id for tool in recipe.tools or []] str(recipe.id) for recipe in recipes if organizer.id in [tool.id for tool in recipe.tools or []]
) }
else: else:
raise ValueError(f"Unknown organizer type: {organizer_type}") raise ValueError(f"Unknown organizer type: {organizer_type}")

View file

@ -1,6 +1,6 @@
from io import BytesIO
import json import json
import zipfile import zipfile
from io import BytesIO
from fastapi.testclient import TestClient from fastapi.testclient import TestClient

View file

@ -29,7 +29,7 @@ def test_recipe_ingredients_parser_nlp(api_client: TestClient, unique_user: Test
response = api_client.post(api_routes.parser_ingredients, json=payload, headers=unique_user.token) response = api_client.post(api_routes.parser_ingredients, json=payload, headers=unique_user.token)
assert response.status_code == 200 assert response.status_code == 200
for api_ingredient, test_ingredient in zip(response.json(), test_ingredients): for api_ingredient, test_ingredient in zip(response.json(), test_ingredients, strict=False):
assert_ingredient(api_ingredient, test_ingredient) assert_ingredient(api_ingredient, test_ingredient)

View file

@ -1,5 +1,5 @@
import random import random
from typing import Generator from collections.abc import Generator
from uuid import UUID from uuid import UUID
import pytest import pytest
@ -71,8 +71,8 @@ def test_user_recipe_favorites(
ratings = response.json()["ratings"] ratings = response.json()["ratings"]
assert len(ratings) == len(recipes_to_favorite) assert len(ratings) == len(recipes_to_favorite)
fetched_recipe_ids = set(rating["recipeId"] for rating in ratings) fetched_recipe_ids = {rating["recipeId"] for rating in ratings}
favorited_recipe_ids = set(str(recipe.id) for recipe in recipes_to_favorite) favorited_recipe_ids = {str(recipe.id) for recipe in recipes_to_favorite}
assert fetched_recipe_ids == favorited_recipe_ids assert fetched_recipe_ids == favorited_recipe_ids
# remove favorites # remove favorites
@ -87,8 +87,8 @@ def test_user_recipe_favorites(
ratings = response.json()["ratings"] ratings = response.json()["ratings"]
assert len(ratings) == len(recipes_to_favorite) - len(recipe_favorites_to_remove) assert len(ratings) == len(recipes_to_favorite) - len(recipe_favorites_to_remove)
fetched_recipe_ids = set(rating["recipeId"] for rating in ratings) fetched_recipe_ids = {rating["recipeId"] for rating in ratings}
removed_recipe_ids = set(str(recipe.id) for recipe in recipe_favorites_to_remove) removed_recipe_ids = {str(recipe.id) for recipe in recipe_favorites_to_remove}
assert fetched_recipe_ids == favorited_recipe_ids - removed_recipe_ids assert fetched_recipe_ids == favorited_recipe_ids - removed_recipe_ids

View file

@ -1,4 +1,4 @@
from typing import Generator from collections.abc import Generator
import pytest import pytest
import sqlalchemy import sqlalchemy

View file

@ -40,7 +40,7 @@ def test_get_all_users_admin(
assert response.status_code == 200 assert response.status_code == 200
# assert all users from all groups are returned # assert all users from all groups are returned
response_user_ids = set(user["id"] for user in response.json()["items"]) response_user_ids = {user["id"] for user in response.json()["items"]}
for user_id in user_ids: for user_id in user_ids:
assert user_id in response_user_ids assert user_id in response_user_ids
@ -73,7 +73,7 @@ def test_get_all_group_users(
user_group = database.groups.get_by_slug_or_id(user.group_id) user_group = database.groups.get_by_slug_or_id(user.group_id)
assert user_group assert user_group
same_group_user_ids: set[str] = set([str(user.user_id)]) same_group_user_ids: set[str] = {user.user_id}
for _ in range(random_int(2, 5)): for _ in range(random_int(2, 5)):
new_user = database.users.create( new_user = database.users.create(
{ {
@ -89,7 +89,7 @@ def test_get_all_group_users(
response = api_client.get(api_routes.users_group_users, params={"perPage": -1}, headers=user.token) response = api_client.get(api_routes.users_group_users, params={"perPage": -1}, headers=user.token)
assert response.status_code == 200 assert response.status_code == 200
response_user_ids = set(user["id"] for user in response.json()["items"]) response_user_ids = {user["id"] for user in response.json()["items"]}
# assert only users from the same group are returned # assert only users from the same group are returned
for user_id in other_group_user_ids: for user_id in other_group_user_ids:

View file

@ -234,7 +234,10 @@ def test_ldap_user_login_simple_filter(api_client: TestClient):
@pytest.mark.skipif(not os.environ.get("GITHUB_ACTIONS", False), reason="requires ldap service in github actions") @pytest.mark.skipif(not os.environ.get("GITHUB_ACTIONS", False), reason="requires ldap service in github actions")
def test_ldap_user_login_complex_filter(api_client: TestClient): def test_ldap_user_login_complex_filter(api_client: TestClient):
settings = get_app_settings() settings = get_app_settings()
settings.LDAP_USER_FILTER = "(&(objectClass=inetOrgPerson)(|(memberOf=cn=ship_crew,ou=people,dc=planetexpress,dc=com)(memberOf=cn=admin_staff,ou=people,dc=planetexpress,dc=com)))" settings.LDAP_USER_FILTER = (
"(&(objectClass=inetOrgPerson)(|(memberOf=cn=ship_crew,ou=people,dc=planetexpress,dc=com)"
"(memberOf=cn=admin_staff,ou=people,dc=planetexpress,dc=com)))"
)
form_data = {"username": "professor", "password": "professor"} form_data = {"username": "professor", "password": "professor"}
response = api_client.post(api_routes.auth_token, data=form_data) response = api_client.post(api_routes.auth_token, data=form_data)

View file

@ -3,14 +3,11 @@ import json
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from mealie.core.config import get_app_settings
from mealie.db.db_setup import session_context from mealie.db.db_setup import session_context
from mealie.repos.repository_factory import AllRepositories from mealie.schema.user.user import PrivateUser
from mealie.schema.response.pagination import PaginationQuery
from mealie.schema.user.user import ChangePassword, PrivateUser
from mealie.services.user_services.password_reset_service import PasswordResetService from mealie.services.user_services.password_reset_service import PasswordResetService
from tests.utils import api_routes from tests.utils import api_routes
from tests.utils.factories import random_email, random_string from tests.utils.factories import random_string
from tests.utils.fixture_schemas import TestUser from tests.utils.fixture_schemas import TestUser
@ -27,7 +24,7 @@ def test_password_reset(api_client: TestClient, unique_user: TestUser, casing: s
cased_email += letter.upper() cased_email += letter.upper()
else: else:
cased_email += letter.lower() cased_email += letter.lower()
cased_email assert cased_email
with session_context() as session: with session_context() as session:
service = PasswordResetService(session) service = PasswordResetService(session)
@ -75,7 +72,7 @@ def test_password_reset_ldap(ldap_user: PrivateUser, casing: str):
cased_email += letter.upper() cased_email += letter.upper()
else: else:
cased_email += letter.lower() cased_email += letter.lower()
cased_email assert cased_email
with session_context() as session: with session_context() as session:
service = PasswordResetService(session) service = PasswordResetService(session)

View file

@ -24,7 +24,7 @@ class ABCMultiTenantTestCase(ABC):
@abstractmethod @abstractmethod
def cleanup(self) -> None: ... def cleanup(self) -> None: ...
def __enter__(self): def __enter__(self): # noqa: B027
pass pass
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):

View file

@ -65,8 +65,8 @@ def test_multitenant_cases_same_named_resources(
): ):
""" """
This test is used to ensure that the same resource can be created with the same values in different tenants. This test is used to ensure that the same resource can be created with the same values in different tenants.
i.e. the same category can exist in multiple groups. This is important to validate that the compound unique constraints i.e. the same category can exist in multiple groups. This is important to validate that the compound unique
are operating in SQLAlchemy correctly. constraints are operating in SQLAlchemy correctly.
""" """
user1 = multitenants.user_one user1 = multitenants.user_one
user2 = multitenants.user_two user2 = multitenants.user_two

View file

@ -308,7 +308,7 @@ def test_pagination_filter_in_advanced(database: AllRepositories, unique_user: T
TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2), TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2),
] ]
tag_1, tag_2 = [database.tags.create(tag) for tag in tags] tag_1, tag_2 = (database.tags.create(tag) for tag in tags)
# Bootstrap the database with recipes # Bootstrap the database with recipes
slug = random_string() slug = random_string()
@ -472,7 +472,7 @@ def test_pagination_filter_logical_namespace_conflict(database: AllRepositories,
CategorySave(group_id=unique_user.group_id, name=random_string(10)), CategorySave(group_id=unique_user.group_id, name=random_string(10)),
CategorySave(group_id=unique_user.group_id, name=random_string(10)), CategorySave(group_id=unique_user.group_id, name=random_string(10)),
] ]
category_1, category_2 = [database.categories.create(category) for category in categories] category_1, category_2 = (database.categories.create(category) for category in categories)
# Bootstrap the database with recipes # Bootstrap the database with recipes
slug = random_string() slug = random_string()
@ -528,7 +528,7 @@ def test_pagination_filter_datetimes(
dt = past_dt.isoformat() dt = past_dt.isoformat()
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 3 assert len(unit_ids) == 3
assert unit_1.id in unit_ids assert unit_1.id in unit_ids
assert unit_2.id in unit_ids assert unit_2.id in unit_ids
@ -537,7 +537,7 @@ def test_pagination_filter_datetimes(
dt = unit_1.created_at.isoformat() # type: ignore dt = unit_1.created_at.isoformat() # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 2 assert len(unit_ids) == 2
assert unit_1.id not in unit_ids assert unit_1.id not in unit_ids
assert unit_2.id in unit_ids assert unit_2.id in unit_ids
@ -546,7 +546,7 @@ def test_pagination_filter_datetimes(
dt = unit_2.created_at.isoformat() # type: ignore dt = unit_2.created_at.isoformat() # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 1 assert len(unit_ids) == 1
assert unit_1.id not in unit_ids assert unit_1.id not in unit_ids
assert unit_2.id not in unit_ids assert unit_2.id not in unit_ids
@ -555,14 +555,14 @@ def test_pagination_filter_datetimes(
dt = unit_3.created_at.isoformat() # type: ignore dt = unit_3.created_at.isoformat() # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 0 assert len(unit_ids) == 0
future_dt: datetime = unit_3.created_at + timedelta(seconds=1) # type: ignore future_dt: datetime = unit_3.created_at + timedelta(seconds=1) # type: ignore
dt = future_dt.isoformat() dt = future_dt.isoformat()
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>"{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 0 assert len(unit_ids) == 0
## GTE ## GTE
@ -570,7 +570,7 @@ def test_pagination_filter_datetimes(
dt = past_dt.isoformat() dt = past_dt.isoformat()
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 3 assert len(unit_ids) == 3
assert unit_1.id in unit_ids assert unit_1.id in unit_ids
assert unit_2.id in unit_ids assert unit_2.id in unit_ids
@ -579,7 +579,7 @@ def test_pagination_filter_datetimes(
dt = unit_1.created_at.isoformat() # type: ignore dt = unit_1.created_at.isoformat() # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 3 assert len(unit_ids) == 3
assert unit_1.id in unit_ids assert unit_1.id in unit_ids
assert unit_2.id in unit_ids assert unit_2.id in unit_ids
@ -588,7 +588,7 @@ def test_pagination_filter_datetimes(
dt = unit_2.created_at.isoformat() # type: ignore dt = unit_2.created_at.isoformat() # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 2 assert len(unit_ids) == 2
assert unit_1.id not in unit_ids assert unit_1.id not in unit_ids
assert unit_2.id in unit_ids assert unit_2.id in unit_ids
@ -597,7 +597,7 @@ def test_pagination_filter_datetimes(
dt = unit_3.created_at.isoformat() # type: ignore dt = unit_3.created_at.isoformat() # type: ignore
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 1 assert len(unit_ids) == 1
assert unit_1.id not in unit_ids assert unit_1.id not in unit_ids
assert unit_2.id not in unit_ids assert unit_2.id not in unit_ids
@ -607,7 +607,7 @@ def test_pagination_filter_datetimes(
dt = future_dt.isoformat() dt = future_dt.isoformat()
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"') query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"')
unit_results = units_repo.page_all(query).items unit_results = units_repo.page_all(query).items
unit_ids = set(unit.id for unit in unit_results) unit_ids = {unit.id for unit in unit_results}
assert len(unit_ids) == 0 assert len(unit_ids) == 0
@ -931,7 +931,7 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
## Yesterday ## Yesterday
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date >= {yesterday.strftime('%Y-%m-%d')}", "queryFilter": f"date >= {yesterday.strftime('%Y-%m-%d')}",
} }
@ -940,12 +940,12 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
response_json = response.json() response_json = response.json()
assert len(response_json["items"]) == 2 assert len(response_json["items"]) == 2
fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) fetched_mealplan_titles = {mp["title"] for mp in response_json["items"]}
assert mealplan_today.title in fetched_mealplan_titles assert mealplan_today.title in fetched_mealplan_titles
assert mealplan_tomorrow.title in fetched_mealplan_titles assert mealplan_tomorrow.title in fetched_mealplan_titles
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date > {yesterday.strftime('%Y-%m-%d')}", "queryFilter": f"date > {yesterday.strftime('%Y-%m-%d')}",
} }
@ -954,13 +954,13 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
response_json = response.json() response_json = response.json()
assert len(response_json["items"]) == 2 assert len(response_json["items"]) == 2
fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) fetched_mealplan_titles = {mp["title"] for mp in response_json["items"]}
assert mealplan_today.title in fetched_mealplan_titles assert mealplan_today.title in fetched_mealplan_titles
assert mealplan_tomorrow.title in fetched_mealplan_titles assert mealplan_tomorrow.title in fetched_mealplan_titles
## Today ## Today
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date >= {today.strftime('%Y-%m-%d')}", "queryFilter": f"date >= {today.strftime('%Y-%m-%d')}",
} }
@ -969,12 +969,12 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
response_json = response.json() response_json = response.json()
assert len(response_json["items"]) == 2 assert len(response_json["items"]) == 2
fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) fetched_mealplan_titles = {mp["title"] for mp in response_json["items"]}
assert mealplan_today.title in fetched_mealplan_titles assert mealplan_today.title in fetched_mealplan_titles
assert mealplan_tomorrow.title in fetched_mealplan_titles assert mealplan_tomorrow.title in fetched_mealplan_titles
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date > {today.strftime('%Y-%m-%d')}", "queryFilter": f"date > {today.strftime('%Y-%m-%d')}",
} }
@ -983,13 +983,13 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
response_json = response.json() response_json = response.json()
assert len(response_json["items"]) == 1 assert len(response_json["items"]) == 1
fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) fetched_mealplan_titles = {mp["title"] for mp in response_json["items"]}
assert mealplan_today.title not in fetched_mealplan_titles assert mealplan_today.title not in fetched_mealplan_titles
assert mealplan_tomorrow.title in fetched_mealplan_titles assert mealplan_tomorrow.title in fetched_mealplan_titles
## Tomorrow ## Tomorrow
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date >= {tomorrow.strftime('%Y-%m-%d')}", "queryFilter": f"date >= {tomorrow.strftime('%Y-%m-%d')}",
} }
@ -998,12 +998,12 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
response_json = response.json() response_json = response.json()
assert len(response_json["items"]) == 1 assert len(response_json["items"]) == 1
fetched_mealplan_titles = set(mp["title"] for mp in response_json["items"]) fetched_mealplan_titles = {mp["title"] for mp in response_json["items"]}
assert mealplan_today.title not in fetched_mealplan_titles assert mealplan_today.title not in fetched_mealplan_titles
assert mealplan_tomorrow.title in fetched_mealplan_titles assert mealplan_tomorrow.title in fetched_mealplan_titles
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date > {tomorrow.strftime('%Y-%m-%d')}", "queryFilter": f"date > {tomorrow.strftime('%Y-%m-%d')}",
} }
@ -1015,7 +1015,7 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
## Day After Tomorrow ## Day After Tomorrow
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date >= {day_after_tomorrow.strftime('%Y-%m-%d')}", "queryFilter": f"date >= {day_after_tomorrow.strftime('%Y-%m-%d')}",
} }
@ -1025,7 +1025,7 @@ def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
assert len(response_json["items"]) == 0 assert len(response_json["items"]) == 0
params = { params = {
f"page": 1, "page": 1,
"perPage": -1, "perPage": -1,
"queryFilter": f"date > {day_after_tomorrow.strftime('%Y-%m-%d')}", "queryFilter": f"date > {day_after_tomorrow.strftime('%Y-%m-%d')}",
} }
@ -1077,20 +1077,20 @@ def test_pagination_filter_advanced_frontend_sort(database: AllRepositories, uni
CategorySave(group_id=unique_user.group_id, name=random_string(10)), CategorySave(group_id=unique_user.group_id, name=random_string(10)),
CategorySave(group_id=unique_user.group_id, name=random_string(10)), CategorySave(group_id=unique_user.group_id, name=random_string(10)),
] ]
category_1, category_2 = [database.categories.create(category) for category in categories] category_1, category_2 = (database.categories.create(category) for category in categories)
slug1, slug2 = (random_string(10) for _ in range(2)) slug1, slug2 = (random_string(10) for _ in range(2))
tags = [ tags = [
TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1), TagSave(group_id=unique_user.group_id, name=slug1, slug=slug1),
TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2), TagSave(group_id=unique_user.group_id, name=slug2, slug=slug2),
] ]
tag_1, tag_2 = [database.tags.create(tag) for tag in tags] tag_1, tag_2 = (database.tags.create(tag) for tag in tags)
tools = [ tools = [
RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)), RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)),
RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)), RecipeToolSave(group_id=unique_user.group_id, name=random_string(10)),
] ]
tool_1, tool_2 = [database.tools.create(tool) for tool in tools] tool_1, tool_2 = (database.tools.create(tool) for tool in tools)
# Bootstrap the database with recipes # Bootstrap the database with recipes
slug = random_string() slug = random_string()

View file

@ -44,7 +44,7 @@ def search_recipes(database: AllRepositories, unique_local_group_id: str, unique
user_id=unique_local_user_id, user_id=unique_local_user_id,
group_id=unique_local_group_id, group_id=unique_local_group_id,
name="Steinbock Sloop", name="Steinbock Sloop",
description=f"My favorite horns are delicious", description="My favorite horns are delicious",
recipe_ingredient=[ recipe_ingredient=[
RecipeIngredient(note="alpine animal"), RecipeIngredient(note="alpine animal"),
], ],
@ -302,7 +302,7 @@ def test_recipe_repo_pagination_by_categories(database: AllRepositories, unique_
order_direction=OrderDirection.asc, order_direction=OrderDirection.asc,
) )
random_ordered = [] random_ordered = []
for i in range(5): for _ in range(5):
pagination_query.pagination_seed = str(datetime.now(timezone.utc)) pagination_query.pagination_seed = str(datetime.now(timezone.utc))
random_ordered.append(database.recipes.page_all(pagination_query, categories=[category_slug]).items) random_ordered.append(database.recipes.page_all(pagination_query, categories=[category_slug]).items)
assert not all(i == random_ordered[0] for i in random_ordered) assert not all(i == random_ordered[0] for i in random_ordered)
@ -395,7 +395,7 @@ def test_recipe_repo_pagination_by_tags(database: AllRepositories, unique_user:
order_direction=OrderDirection.asc, order_direction=OrderDirection.asc,
) )
random_ordered = [] random_ordered = []
for i in range(5): for _ in range(5):
pagination_query.pagination_seed = str(datetime.now(timezone.utc)) pagination_query.pagination_seed = str(datetime.now(timezone.utc))
random_ordered.append(database.recipes.page_all(pagination_query, tags=[tag_slug]).items) random_ordered.append(database.recipes.page_all(pagination_query, tags=[tag_slug]).items)
assert len(random_ordered[0]) == 15 assert len(random_ordered[0]) == 15
@ -491,7 +491,7 @@ def test_recipe_repo_pagination_by_tools(database: AllRepositories, unique_user:
order_direction=OrderDirection.asc, order_direction=OrderDirection.asc,
) )
random_ordered = [] random_ordered = []
for i in range(5): for _ in range(5):
pagination_query.pagination_seed = str(datetime.now(timezone.utc)) pagination_query.pagination_seed = str(datetime.now(timezone.utc))
random_ordered.append(database.recipes.page_all(pagination_query, tools=[tool_id]).items) random_ordered.append(database.recipes.page_all(pagination_query, tools=[tool_id]).items)
assert len(random_ordered[0]) == 15 assert len(random_ordered[0]) == 15
@ -575,7 +575,7 @@ def test_recipe_repo_pagination_by_foods(database: AllRepositories, unique_user:
order_direction=OrderDirection.asc, order_direction=OrderDirection.asc,
) )
random_ordered = [] random_ordered = []
for i in range(5): for _ in range(5):
pagination_query.pagination_seed = str(datetime.now(timezone.utc)) pagination_query.pagination_seed = str(datetime.now(timezone.utc))
random_ordered.append(database.recipes.page_all(pagination_query, foods=[food_id]).items) random_ordered.append(database.recipes.page_all(pagination_query, foods=[food_id]).items)
assert len(random_ordered[0]) == 15 assert len(random_ordered[0]) == 15

View file

@ -25,7 +25,7 @@ from mealie.services.backups_v2.backup_v2 import BackupV2
def dict_sorter(d: dict) -> Any: def dict_sorter(d: dict) -> Any:
possible_keys = {"created_at", "id"} possible_keys = {"created_at", "id"}
return next((d[key] for key in possible_keys if key in d and d[key]), 1) return next((d[key] for key in possible_keys if d.get(key)), 1)
# For Future Use # For Future Use
@ -68,7 +68,7 @@ def test_database_restore():
new_exporter = AlchemyExporter(settings.DB_URL) new_exporter = AlchemyExporter(settings.DB_URL)
snapshop_2 = new_exporter.dump() snapshop_2 = new_exporter.dump()
for s1, s2 in zip(snapshop_1, snapshop_2): for s1, s2 in zip(snapshop_1, snapshop_2, strict=False):
assert snapshop_1[s1].sort(key=dict_sorter) == snapshop_2[s2].sort(key=dict_sorter) assert snapshop_1[s1].sort(key=dict_sorter) == snapshop_2[s2].sort(key=dict_sorter)

View file

@ -34,7 +34,7 @@ def test_get_locked_users(database: AllRepositories, user_tuple: list[TestUser])
elif locked_user.id == user_2.id: elif locked_user.id == user_2.id:
assert locked_user.locked_at == user_2.locked_at assert locked_user.locked_at == user_2.locked_at
else: else:
assert False raise AssertionError()
# Cleanup # Cleanup
user_service.unlock_user(user_1) user_service.unlock_user(user_1)

View file

@ -145,7 +145,7 @@ def test_nlp_parser() -> None:
models: list[CRFIngredient] = convert_list_to_crf_model([x.input for x in test_ingredients]) models: list[CRFIngredient] = convert_list_to_crf_model([x.input for x in test_ingredients])
# Iterate over models and test_ingredients to gather # Iterate over models and test_ingredients to gather
for model, test_ingredient in zip(models, test_ingredients): for model, test_ingredient in zip(models, test_ingredients, strict=False):
assert round(float(sum(Fraction(s) for s in model.qty.split())), 3) == pytest.approx(test_ingredient.quantity) assert round(float(sum(Fraction(s) for s in model.qty.split())), 3) == pytest.approx(test_ingredient.quantity)
assert model.comment == test_ingredient.comments assert model.comment == test_ingredient.comments

View file

@ -27,9 +27,9 @@ class LdapConnMock:
self.name = name self.name = name
def simple_bind_s(self, dn, bind_pw): def simple_bind_s(self, dn, bind_pw):
if dn == "cn={}, {}".format(self.user, self.app_settings.LDAP_BASE_DN): if dn == f"cn={self.user}, {self.app_settings.LDAP_BASE_DN}":
valid_password = self.password valid_password = self.password
elif "cn={}, {}".format(self.query_bind, self.app_settings.LDAP_BASE_DN): elif f"cn={self.query_bind}, {self.app_settings.LDAP_BASE_DN}":
valid_password = self.query_password valid_password = self.query_password
if bind_pw == valid_password: if bind_pw == valid_password:
@ -42,7 +42,7 @@ class LdapConnMock:
if filter == self.app_settings.LDAP_ADMIN_FILTER: if filter == self.app_settings.LDAP_ADMIN_FILTER:
assert attrlist == [] assert attrlist == []
assert filter == self.app_settings.LDAP_ADMIN_FILTER assert filter == self.app_settings.LDAP_ADMIN_FILTER
assert dn == "cn={}, {}".format(self.user, self.app_settings.LDAP_BASE_DN) assert dn == f"cn={self.user}, {self.app_settings.LDAP_BASE_DN}"
assert scope == ldap.SCOPE_BASE assert scope == ldap.SCOPE_BASE
if not self.admin: if not self.admin:
@ -60,11 +60,9 @@ class LdapConnMock:
mail_attribute=self.app_settings.LDAP_MAIL_ATTRIBUTE, mail_attribute=self.app_settings.LDAP_MAIL_ATTRIBUTE,
input=self.user, input=self.user,
) )
search_filter = "(&(|({id_attribute}={input})({mail_attribute}={input})){filter})".format( search_filter = (
id_attribute=self.app_settings.LDAP_ID_ATTRIBUTE, f"(&(|({self.app_settings.LDAP_ID_ATTRIBUTE}={self.user})"
mail_attribute=self.app_settings.LDAP_MAIL_ATTRIBUTE, f"({self.app_settings.LDAP_MAIL_ATTRIBUTE}={self.user})){user_filter})"
input=self.user,
filter=user_filter,
) )
assert filter == search_filter assert filter == search_filter
assert dn == self.app_settings.LDAP_BASE_DN assert dn == self.app_settings.LDAP_BASE_DN
@ -72,7 +70,7 @@ class LdapConnMock:
return [ return [
( (
"cn={}, {}".format(self.user, self.app_settings.LDAP_BASE_DN), f"cn={self.user}, {self.app_settings.LDAP_BASE_DN}",
{ {
self.app_settings.LDAP_ID_ATTRIBUTE: [self.user.encode()], self.app_settings.LDAP_ID_ATTRIBUTE: [self.user.encode()],
self.app_settings.LDAP_NAME_ATTRIBUTE: [self.name.encode()], self.app_settings.LDAP_NAME_ATTRIBUTE: [self.name.encode()],