diff --git a/mealie/repos/repository_generic.py b/mealie/repos/repository_generic.py index 76abb55b7..85419090d 100644 --- a/mealie/repos/repository_generic.py +++ b/mealie/repos/repository_generic.py @@ -16,6 +16,7 @@ from mealie.db.models._model_base import SqlAlchemyBase from mealie.schema._mealie import MealieModel from mealie.schema.response.pagination import OrderDirection, PaginationBase, PaginationQuery from mealie.schema.response.query_filter import QueryFilter +from mealie.schema.response.query_search import SearchFilter Schema = TypeVar("Schema", bound=MealieModel) Model = TypeVar("Model", bound=SqlAlchemyBase) @@ -291,7 +292,7 @@ class RepositoryGeneric(Generic[Schema, Model]): q = self._query(override_schema=eff_schema).filter(attribute_name == attr_match) return [eff_schema.from_orm(x) for x in self.session.execute(q).scalars().all()] - def page_all(self, pagination: PaginationQuery, override=None) -> PaginationBase[Schema]: + def page_all(self, pagination: PaginationQuery, override=None, search: str | None = None) -> PaginationBase[Schema]: """ pagination is a method to interact with the filtered database table and return a paginated result using the PaginationBase that provides several data points that are needed to manage pagination @@ -302,12 +303,16 @@ class RepositoryGeneric(Generic[Schema, Model]): as the override, as the type system is not able to infer the result of this method. """ eff_schema = override or self.schema - + # Copy this, because calling methods (e.g. tests) might rely on it not getting mutated + pagination_result = pagination.copy() q = self._query(override_schema=eff_schema, with_options=False) fltr = self._filter_builder() q = q.filter_by(**fltr) - q, count, total_pages = self.add_pagination_to_query(q, pagination) + if search: + q = self.add_search_to_query(q, eff_schema, search) + + q, count, total_pages = self.add_pagination_to_query(q, pagination_result) # Apply options late, so they do not get used for counting q = q.options(*eff_schema.loader_options()) @@ -318,8 +323,8 @@ class RepositoryGeneric(Generic[Schema, Model]): self.session.rollback() raise e return PaginationBase( - page=pagination.page, - per_page=pagination.per_page, + page=pagination_result.page, + per_page=pagination_result.per_page, total=count, total_pages=total_pages, items=[eff_schema.from_orm(s) for s in data], @@ -392,3 +397,7 @@ class RepositoryGeneric(Generic[Schema, Model]): query = query.order_by(case_stmt) return query.limit(pagination.per_page).offset((pagination.page - 1) * pagination.per_page), count, total_pages + + def add_search_to_query(self, query: Select, schema: type[Schema], search: str) -> Select: + search_filter = SearchFilter(self.session, search, schema._normalize_search) + return search_filter.filter_query_by_search(query, schema, self.model) diff --git a/mealie/repos/repository_recipes.py b/mealie/repos/repository_recipes.py index a782daf47..6ab555d58 100644 --- a/mealie/repos/repository_recipes.py +++ b/mealie/repos/repository_recipes.py @@ -5,10 +5,9 @@ from uuid import UUID from pydantic import UUID4 from slugify import slugify -from sqlalchemy import Select, and_, desc, func, or_, select, text +from sqlalchemy import and_, func, select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import joinedload -from text_unidecode import unidecode from mealie.db.models.recipe.category import Category from mealie.db.models.recipe.ingredient import RecipeIngredientModel @@ -18,13 +17,7 @@ from mealie.db.models.recipe.tag import Tag from 
mealie.db.models.recipe.tool import Tool from mealie.schema.cookbook.cookbook import ReadCookBook from mealie.schema.recipe import Recipe -from mealie.schema.recipe.recipe import ( - RecipeCategory, - RecipePagination, - RecipeSummary, - RecipeTag, - RecipeTool, -) +from mealie.schema.recipe.recipe import RecipeCategory, RecipePagination, RecipeSummary, RecipeTag, RecipeTool from mealie.schema.recipe.recipe_category import CategoryBase, TagBase from mealie.schema.response.pagination import PaginationQuery @@ -151,98 +144,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]): additional_ids = self.session.execute(select(model.id).filter(model.slug.in_(slugs))).scalars().all() return ids + additional_ids - def _add_search_to_query(self, query: Select, search: str) -> Select: - """ - 0. fuzzy search (postgres only) and tokenized search are performed separately - 1. take search string and do a little pre-normalization - 2. look for internal quoted strings and keep them together as "literal" parts of the search - 3. remove special characters from each non-literal search string - 4. token search looks for any individual exact hit in name, description, and ingredients - 5. fuzzy search looks for trigram hits in name, description, and ingredients - 6. Sort order is determined by closeness to the recipe name - Should search also look at tags? - """ - - normalized_search = unidecode(search).lower().strip() - punctuation = "!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed - # keep quoted phrases together as literal portions of the search string - literal = False - quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""") # thank you stack exchange! - removequotes_regex = re.compile(r"""['"](.*)['"]""") - if quoted_regex.search(normalized_search): - literal = True - temp = normalized_search - quoted_search_list = [match.group() for match in quoted_regex.finditer(temp)] # all quoted strings - quoted_search_list = [removequotes_regex.sub("\\1", x) for x in quoted_search_list] # remove outer quotes - temp = quoted_regex.sub("", temp) # remove all quoted strings, leaving just non-quoted - temp = temp.translate( - str.maketrans(punctuation, " " * len(punctuation)) - ) # punctuation->spaces for splitting, but only on unquoted strings - unquoted_search_list = temp.split() # all unquoted strings - normalized_search_list = quoted_search_list + unquoted_search_list - else: - # - normalized_search = normalized_search.translate(str.maketrans(punctuation, " " * len(punctuation))) - normalized_search_list = normalized_search.split() - normalized_search_list = [x.strip() for x in normalized_search_list] # remove padding whitespace inside quotes - # I would prefer to just do this in the recipe_ingredient.any part of the main query, but it turns out - # that at least sqlite wont use indexes for that correctly anymore and takes a big hit, so prefiltering it is - if (self.session.get_bind().name == "postgresql") & (literal is False): # fuzzy search - ingredient_ids = ( - self.session.execute( - select(RecipeIngredientModel.id).filter( - or_( - RecipeIngredientModel.note_normalized.op("%>")(normalized_search), - RecipeIngredientModel.original_text_normalized.op("%>")(normalized_search), - ) - ) - ) - .scalars() - .all() - ) - else: # exact token search - ingredient_ids = ( - self.session.execute( - select(RecipeIngredientModel.id).filter( - or_( - *[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in normalized_search_list], - *[ - 
RecipeIngredientModel.original_text_normalized.like(f"%{ns}%") - for ns in normalized_search_list - ], - ) - ) - ) - .scalars() - .all() - ) - - if (self.session.get_bind().name == "postgresql") & (literal is False): # fuzzy search - # default = 0.7 is too strict for effective fuzzing - self.session.execute(text("set pg_trgm.word_similarity_threshold = 0.5;")) - q = query.filter( - or_( - RecipeModel.name_normalized.op("%>")(normalized_search), - RecipeModel.description_normalized.op("%>")(normalized_search), - RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)), - ) - ).order_by( # trigram ordering could be too slow on million record db, but is fine with thousands. - func.least( - RecipeModel.name_normalized.op("<->>")(normalized_search), - ) - ) - else: # exact token search - q = query.filter( - or_( - *[RecipeModel.name_normalized.like(f"%{ns}%") for ns in normalized_search_list], - *[RecipeModel.description_normalized.like(f"%{ns}%") for ns in normalized_search_list], - RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)), - ) - ).order_by(desc(RecipeModel.name_normalized.like(f"%{normalized_search}%"))) - - return q - - def page_all( + def page_all( # type: ignore self, pagination: PaginationQuery, override=None, @@ -299,7 +201,7 @@ class RepositoryRecipes(RepositoryGeneric[Recipe, RecipeModel]): ) q = q.filter(*filters) if search: - q = self._add_search_to_query(q, search) + q = self.add_search_to_query(q, self.schema, search) q, count, total_pages = self.add_pagination_to_query(q, pagination_result) diff --git a/mealie/routes/groups/controller_labels.py b/mealie/routes/groups/controller_labels.py index 2b0db90da..45f3c20d8 100644 --- a/mealie/routes/groups/controller_labels.py +++ b/mealie/routes/groups/controller_labels.py @@ -41,10 +41,11 @@ class MultiPurposeLabelsController(BaseUserController): return HttpRepo(self.repo, self.logger, self.registered_exceptions, self.t("generic.server-error")) @router.get("", response_model=MultiPurposeLabelPagination) - def get_all(self, q: PaginationQuery = Depends(PaginationQuery)): + def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None): response = self.repo.page_all( pagination=q, override=MultiPurposeLabelSummary, + search=search, ) response.set_pagination_guides(router.url_path_for("get_all"), q.dict()) diff --git a/mealie/routes/organizers/controller_categories.py b/mealie/routes/organizers/controller_categories.py index 3e584225b..a06b42b25 100644 --- a/mealie/routes/organizers/controller_categories.py +++ b/mealie/routes/organizers/controller_categories.py @@ -38,11 +38,12 @@ class RecipeCategoryController(BaseCrudController): return HttpRepo(self.repo, self.logger) @router.get("", response_model=RecipeCategoryPagination) - def get_all(self, q: PaginationQuery = Depends(PaginationQuery)): + def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None): """Returns a list of available categories in the database""" response = self.repo.page_all( pagination=q, override=RecipeCategory, + search=search, ) response.set_pagination_guides(router.url_path_for("get_all"), q.dict()) diff --git a/mealie/routes/organizers/controller_tags.py b/mealie/routes/organizers/controller_tags.py index 5f993890d..c258d3cf7 100644 --- a/mealie/routes/organizers/controller_tags.py +++ b/mealie/routes/organizers/controller_tags.py @@ -27,11 +27,12 @@ class TagController(BaseCrudController): return HttpRepo(self.repo, self.logger) 
@router.get("", response_model=RecipeTagPagination) - async def get_all(self, q: PaginationQuery = Depends(PaginationQuery)): + async def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None): """Returns a list of available tags in the database""" response = self.repo.page_all( pagination=q, override=RecipeTag, + search=search, ) response.set_pagination_guides(router.url_path_for("get_all"), q.dict()) diff --git a/mealie/routes/organizers/controller_tools.py b/mealie/routes/organizers/controller_tools.py index 3d42f51e3..c1fc424bf 100644 --- a/mealie/routes/organizers/controller_tools.py +++ b/mealie/routes/organizers/controller_tools.py @@ -25,10 +25,11 @@ class RecipeToolController(BaseUserController): return HttpRepo[RecipeToolCreate, RecipeTool, RecipeToolCreate](self.repo, self.logger) @router.get("", response_model=RecipeToolPagination) - def get_all(self, q: PaginationQuery = Depends(PaginationQuery)): + def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None): response = self.repo.page_all( pagination=q, override=RecipeTool, + search=search, ) response.set_pagination_guides(router.url_path_for("get_all"), q.dict()) diff --git a/mealie/routes/unit_and_foods/foods.py b/mealie/routes/unit_and_foods/foods.py index 95e63a2a8..f3fce5391 100644 --- a/mealie/routes/unit_and_foods/foods.py +++ b/mealie/routes/unit_and_foods/foods.py @@ -45,10 +45,11 @@ class IngredientFoodsController(BaseUserController): raise HTTPException(500, "Failed to merge foods") from e @router.get("", response_model=IngredientFoodPagination) - def get_all(self, q: PaginationQuery = Depends(PaginationQuery)): + def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None): response = self.repo.page_all( pagination=q, override=IngredientFood, + search=search, ) response.set_pagination_guides(router.url_path_for("get_all"), q.dict()) diff --git a/mealie/routes/unit_and_foods/units.py b/mealie/routes/unit_and_foods/units.py index 83f193a3d..0c6c3087f 100644 --- a/mealie/routes/unit_and_foods/units.py +++ b/mealie/routes/unit_and_foods/units.py @@ -45,10 +45,11 @@ class IngredientUnitsController(BaseUserController): raise HTTPException(500, "Failed to merge units") from e @router.get("", response_model=IngredientUnitPagination) - def get_all(self, q: PaginationQuery = Depends(PaginationQuery)): + def get_all(self, q: PaginationQuery = Depends(PaginationQuery), search: str | None = None): response = self.repo.page_all( pagination=q, override=IngredientUnit, + search=search, ) response.set_pagination_guides(router.url_path_for("get_all"), q.dict()) diff --git a/mealie/schema/_mealie/__init__.py b/mealie/schema/_mealie/__init__.py index bcca9f6ea..6712561a6 100644 --- a/mealie/schema/_mealie/__init__.py +++ b/mealie/schema/_mealie/__init__.py @@ -1,7 +1,8 @@ # This file is auto-generated by gen_schema_exports.py -from .mealie_model import HasUUID, MealieModel +from .mealie_model import HasUUID, MealieModel, SearchType __all__ = [ "HasUUID", "MealieModel", + "SearchType", ] diff --git a/mealie/schema/_mealie/mealie_model.py b/mealie/schema/_mealie/mealie_model.py index 1777458b5..979d31537 100644 --- a/mealie/schema/_mealie/mealie_model.py +++ b/mealie/schema/_mealie/mealie_model.py @@ -1,16 +1,34 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Protocol, TypeVar +from enum import Enum +from typing import ClassVar, Protocol, TypeVar from humps.main import camelize from pydantic 
import UUID4, BaseModel +from sqlalchemy import Select, desc, func, or_, text +from sqlalchemy.orm import InstrumentedAttribute, Session from sqlalchemy.orm.interfaces import LoaderOption +from mealie.db.models._model_base import SqlAlchemyBase + T = TypeVar("T", bound=BaseModel) +class SearchType(Enum): + fuzzy = "fuzzy" + tokenized = "tokenized" + + class MealieModel(BaseModel): + _fuzzy_similarity_threshold: ClassVar[float] = 0.5 + _normalize_search: ClassVar[bool] = False + _searchable_properties: ClassVar[list[str]] = [] + """ + Searchable properties for the search API. + The first property will be used for sorting (order_by) + """ + class Config: alias_generator = camelize allow_population_by_field_name = True @@ -59,6 +77,40 @@ class MealieModel(BaseModel): def loader_options(cls) -> list[LoaderOption]: return [] + @classmethod + def filter_search_query( + cls, + db_model: type[SqlAlchemyBase], + query: Select, + session: Session, + search_type: SearchType, + search: str, + search_list: list[str], + ) -> Select: + """ + Filters a search query based on model attributes + + Can be overridden to support a more advanced search + """ + + if not cls._searchable_properties: + raise AttributeError("Not Implemented") + + model_properties: list[InstrumentedAttribute] = [getattr(db_model, prop) for prop in cls._searchable_properties] + if search_type is SearchType.fuzzy: + session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};")) + filters = [prop.op("%>")(search) for prop in model_properties] + + # trigram ordering by the first searchable property + return query.filter(or_(*filters)).order_by(func.least(model_properties[0].op("<->>")(search))) + else: + filters = [] + for prop in model_properties: + filters.extend([prop.like(f"%{s}%") for s in search_list]) + + # order by how close the result is to the first searchable property + return query.filter(or_(*filters)).order_by(desc(model_properties[0].like(f"%{search}%"))) + class HasUUID(Protocol): id: UUID4 diff --git a/mealie/schema/labels/multi_purpose_label.py b/mealie/schema/labels/multi_purpose_label.py index 9faea5d98..eb41b23c9 100644 --- a/mealie/schema/labels/multi_purpose_label.py +++ b/mealie/schema/labels/multi_purpose_label.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import ClassVar + from pydantic import UUID4 from mealie.schema._mealie import MealieModel @@ -20,7 +22,7 @@ class MultiPurposeLabelUpdate(MultiPurposeLabelSave): class MultiPurposeLabelSummary(MultiPurposeLabelUpdate): - pass + _searchable_properties: ClassVar[list[str]] = ["name"] class Config: orm_mode = True @@ -31,14 +33,5 @@ class MultiPurposeLabelPagination(PaginationBase): class MultiPurposeLabelOut(MultiPurposeLabelUpdate): - # shopping_list_items: list[ShoppingListItemOut] = [] - # foods: list[IngredientFood] = [] - class Config: orm_mode = True - - -# from mealie.schema.recipe.recipe_ingredient import IngredientFood -# from mealie.schema.group.group_shopping_list import ShoppingListItemOut - -# MultiPurposeLabelOut.update_forward_refs() diff --git a/mealie/schema/recipe/recipe.py b/mealie/schema/recipe/recipe.py index 22c812e4e..ba49e00cd 100644 --- a/mealie/schema/recipe/recipe.py +++ b/mealie/schema/recipe/recipe.py @@ -2,16 +2,17 @@ from __future__ import annotations import datetime from pathlib import Path -from typing import Any +from typing import Any, ClassVar from uuid import uuid4 from pydantic import UUID4, BaseModel, Field, validator from slugify import slugify -from sqlalchemy.orm 
import joinedload, selectinload +from sqlalchemy import Select, desc, func, or_, select, text +from sqlalchemy.orm import Session, joinedload, selectinload from sqlalchemy.orm.interfaces import LoaderOption from mealie.core.config import get_app_dirs -from mealie.schema._mealie import MealieModel +from mealie.schema._mealie import MealieModel, SearchType from mealie.schema.response.pagination import PaginationBase from ...db.models.recipe import ( @@ -37,6 +38,8 @@ class RecipeTag(MealieModel): name: str slug: str + _searchable_properties: ClassVar[list[str]] = ["name"] + class Config: orm_mode = True @@ -78,6 +81,7 @@ class CreateRecipe(MealieModel): class RecipeSummary(MealieModel): id: UUID4 | None + _normalize_search: ClassVar[bool] = True user_id: UUID4 = Field(default_factory=uuid4) group_id: UUID4 = Field(default_factory=uuid4) @@ -259,6 +263,69 @@ class Recipe(RecipeSummary): selectinload(RecipeModel.notes), ] + @classmethod + def filter_search_query( + cls, db_model, query: Select, session: Session, search_type: SearchType, search: str, search_list: list[str] + ) -> Select: + """ + 1. token search looks for any individual exact hit in name, description, and ingredients + 2. fuzzy search looks for trigram hits in name, description, and ingredients + 3. Sort order is determined by closeness to the recipe name + Should search also look at tags? + """ + + if search_type is SearchType.fuzzy: + # I would prefer to just do this in the recipe_ingredient.any part of the main query, + # but it turns out that at least sqlite wont use indexes for that correctly anymore and + # takes a big hit, so prefiltering it is + ingredient_ids = ( + session.execute( + select(RecipeIngredientModel.id).filter( + or_( + RecipeIngredientModel.note_normalized.op("%>")(search), + RecipeIngredientModel.original_text_normalized.op("%>")(search), + ) + ) + ) + .scalars() + .all() + ) + + session.execute(text(f"set pg_trgm.word_similarity_threshold = {cls._fuzzy_similarity_threshold};")) + return query.filter( + or_( + RecipeModel.name_normalized.op("%>")(search), + RecipeModel.description_normalized.op("%>")(search), + RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)), + ) + ).order_by( # trigram ordering could be too slow on million record db, but is fine with thousands. 
+ func.least( + RecipeModel.name_normalized.op("<->>")(search), + ) + ) + + else: + ingredient_ids = ( + session.execute( + select(RecipeIngredientModel.id).filter( + or_( + *[RecipeIngredientModel.note_normalized.like(f"%{ns}%") for ns in search_list], + *[RecipeIngredientModel.original_text_normalized.like(f"%{ns}%") for ns in search_list], + ) + ) + ) + .scalars() + .all() + ) + + return query.filter( + or_( + *[RecipeModel.name_normalized.like(f"%{ns}%") for ns in search_list], + *[RecipeModel.description_normalized.like(f"%{ns}%") for ns in search_list], + RecipeModel.recipe_ingredient.any(RecipeIngredientModel.id.in_(ingredient_ids)), + ) + ).order_by(desc(RecipeModel.name_normalized.like(f"%{search}%"))) + class RecipeLastMade(BaseModel): timestamp: datetime.datetime diff --git a/mealie/schema/recipe/recipe_ingredient.py b/mealie/schema/recipe/recipe_ingredient.py index ef4facb60..36006b9c1 100644 --- a/mealie/schema/recipe/recipe_ingredient.py +++ b/mealie/schema/recipe/recipe_ingredient.py @@ -3,6 +3,7 @@ from __future__ import annotations import datetime import enum from fractions import Fraction +from typing import ClassVar from uuid import UUID, uuid4 from pydantic import UUID4, Field, validator @@ -50,6 +51,8 @@ class IngredientFood(CreateIngredientFood): created_at: datetime.datetime | None update_at: datetime.datetime | None + _searchable_properties: ClassVar[list[str]] = ["name", "description"] + class Config: orm_mode = True getter_dict = ExtrasGetterDict @@ -78,6 +81,8 @@ class IngredientUnit(CreateIngredientUnit): created_at: datetime.datetime | None update_at: datetime.datetime | None + _searchable_properties: ClassVar[list[str]] = ["name", "abbreviation", "description"] + class Config: orm_mode = True diff --git a/mealie/schema/response/query_search.py b/mealie/schema/response/query_search.py new file mode 100644 index 000000000..21c390b39 --- /dev/null +++ b/mealie/schema/response/query_search.py @@ -0,0 +1,67 @@ +import re + +from sqlalchemy import Select +from sqlalchemy.orm import Session +from text_unidecode import unidecode + +from ...db.models._model_base import SqlAlchemyBase +from .._mealie import MealieModel, SearchType + + +class SearchFilter: + """ + 0. fuzzy search (postgres only) and tokenized search are performed separately + 1. take search string and do a little pre-normalization + 2. look for internal quoted strings and keep them together as "literal" parts of the search + 3. 
remove special characters from each non-literal search string + """ + + punctuation = "!\#$%&()*+,-./:;<=>?@[\\]^_`{|}~" # string.punctuation with ' & " removed + quoted_regex = re.compile(r"""(["'])(?:(?=(\\?))\2.)*?\1""") + remove_quotes_regex = re.compile(r"""['"](.*)['"]""") + + @classmethod + def _normalize_search(cls, search: str, normalize_characters: bool) -> str: + search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation))) + + if normalize_characters: + search = unidecode(search).lower().strip() + else: + search = search.strip() + + return search + + @classmethod + def _build_search_list(cls, search: str) -> list[str]: + if cls.quoted_regex.search(search): + # all quoted strings + quoted_search_list = [match.group() for match in cls.quoted_regex.finditer(search)] + + # remove outer quotes + quoted_search_list = [cls.remove_quotes_regex.sub("\\1", x) for x in quoted_search_list] + + # punctuation->spaces for splitting, but only on unquoted strings + search = cls.quoted_regex.sub("", search) # remove all quoted strings, leaving just non-quoted + search = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation))) + + # all unquoted strings + unquoted_search_list = search.split() + search_list = quoted_search_list + unquoted_search_list + else: + search_list = search.translate(str.maketrans(cls.punctuation, " " * len(cls.punctuation))).split() + + # remove padding whitespace inside quotes + return [x.strip() for x in search_list] + + def __init__(self, session: Session, search: str, normalize_characters: bool = False) -> None: + if session.get_bind().name != "postgresql" or self.quoted_regex.search(search.strip()): + self.search_type = SearchType.tokenized + else: + self.search_type = SearchType.fuzzy + + self.session = session + self.search = self._normalize_search(search, normalize_characters) + self.search_list = self._build_search_list(self.search) + + def filter_query_by_search(self, query: Select, schema: type[MealieModel], model: type[SqlAlchemyBase]) -> Select: + return schema.filter_search_query(model, query, self.session, self.search_type, self.search, self.search_list) diff --git a/tests/unit_tests/repository_tests/test_recipe_repository.py b/tests/unit_tests/repository_tests/test_recipe_repository.py index 9cce9dd04..4ebb502e1 100644 --- a/tests/unit_tests/repository_tests/test_recipe_repository.py +++ b/tests/unit_tests/repository_tests/test_recipe_repository.py @@ -1,17 +1,108 @@ from datetime import datetime from typing import cast +import pytest + from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_recipes import RepositoryRecipes -from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood, RecipeStep +from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood from mealie.schema.recipe.recipe import Recipe, RecipeCategory, RecipeSummary from mealie.schema.recipe.recipe_category import CategoryOut, CategorySave, TagSave from mealie.schema.recipe.recipe_tool import RecipeToolSave from mealie.schema.response import OrderDirection, PaginationQuery -from tests.utils.factories import random_string +from mealie.schema.user.user import GroupBase +from tests.utils.factories import random_email, random_string from tests.utils.fixture_schemas import TestUser +@pytest.fixture() +def unique_local_group_id(database: AllRepositories) -> str: + return str(database.groups.create(GroupBase(name=random_string())).id) + + +@pytest.fixture() +def unique_local_user_id(database: 
AllRepositories, unique_local_group_id: str) -> str: + return str( + database.users.create( + { + "username": random_string(), + "email": random_email(), + "group_id": unique_local_group_id, + "full_name": random_string(), + "password": random_string(), + "admin": False, + } + ).id + ) + + +@pytest.fixture() +def search_recipes(database: AllRepositories, unique_local_group_id: str, unique_local_user_id: str) -> list[Recipe]: + recipes = [ + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name="Steinbock Sloop", + description="My favorite horns are delicious", + recipe_ingredient=[ + RecipeIngredient(note="alpine animal"), + ], + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name="Fiddlehead Fern Stir Fry", + recipe_ingredient=[ + RecipeIngredient(note="moss"), + ], + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name="Animal Sloop", + ), + # Test diacritics + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name="Rátàtôuile", + ), + # Add a bunch of recipes for stable randomization + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name=f"{random_string(10)} soup", + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name=f"{random_string(10)} soup", + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name=f"{random_string(10)} soup", + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name=f"{random_string(10)} soup", + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name=f"{random_string(10)} soup", + ), + Recipe( + user_id=unique_local_user_id, + group_id=unique_local_group_id, + name=f"{random_string(10)} soup", + ), + ] + + return database.recipes.create_many(recipes) + + def test_recipe_repo_get_by_categories_basic(database: AllRepositories, unique_user: TestUser): # Bootstrap the database with categories slug1, slug2, slug3 = (random_string(10) for _ in range(3)) @@ -490,129 +581,72 @@ def test_recipe_repo_pagination_by_foods(database: AllRepositories, unique_user: assert not all(i == random_ordered[0] for i in random_ordered) -def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser): - recipes = [ - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name="Steinbock Sloop", - description=f"My favorite horns are delicious", - recipe_ingredient=[ - RecipeIngredient(note="alpine animal"), - ], - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name="Fiddlehead Fern Stir Fry", - recipe_ingredient=[ - RecipeIngredient(note="moss"), - ], - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name="Animal Sloop", - ), - # Test diacritics - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name="Rátàtôuile", - ), - # Add a bunch of recipes for stable randomization - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id,
- name=f"{random_string(10)} soup", - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name=f"{random_string(10)} soup", - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name=f"{random_string(10)} soup", - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name=f"{random_string(10)} soup", - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name=f"{random_string(10)} soup", - ), - Recipe( - user_id=unique_user.user_id, - group_id=unique_user.group_id, - name=f"{random_string(10)} soup", - ), - ] +@pytest.mark.parametrize( + "search, expected_names", + [ + (random_string(), []), + ("Steinbock", ["Steinbock Sloop"]), + ("horns", ["Steinbock Sloop"]), + ("moss", ["Fiddlehead Fern Stir Fry"]), + ('"Animal Sloop"', ["Animal Sloop"]), + ("animal-sloop", ["Animal Sloop"]), + ("ratat", ["Rátàtôuile"]), + ("delicious horns", ["Steinbock Sloop"]), + ], + ids=[ + "no_match", + "search_by_title", + "search_by_description", + "search_by_ingredient", + "literal_search", + "special_character_removal", + "normalization", + "token_separation", + ], +) +def test_basic_recipe_search( + search: str, + expected_names: list[str], + database: AllRepositories, + search_recipes: list[Recipe], # required so database is populated + unique_local_group_id: str, +): + repo = database.recipes.by_group(unique_local_group_id) # type: ignore + pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc) + results = repo.page_all(pagination, search=search).items - for recipe in recipes: - database.recipes.create(recipe) + if len(expected_names) == 0: + assert len(results) == 0 + else: + # if more results are returned, that's acceptable, as long as they are ranked correctly + assert len(results) >= len(expected_names) + for recipe, name in zip(results, expected_names, strict=False): + assert recipe.name == name - pagination_query = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc) - # No hits - empty_result = database.recipes.page_all(pagination_query, search=random_string(10)).items - assert len(empty_result) == 0 +def test_fuzzy_recipe_search( + database: AllRepositories, + search_recipes: list[Recipe], # required so database is populated + unique_local_group_id: str, +): + # this only works on postgres + if database.session.get_bind().name != "postgresql": + return - # Search by title - title_result = database.recipes.page_all(pagination_query, search="Steinbock").items - assert len(title_result) == 1 - assert title_result[0].name == "Steinbock Sloop" + repo = database.recipes.by_group(unique_local_group_id) # type: ignore + pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc) + results = repo.page_all(pagination, search="Steinbuck").items - # Search by description - description_result = database.recipes.page_all(pagination_query, search="horns").items - assert len(description_result) == 1 - assert description_result[0].name == "Steinbock Sloop" + assert results and results[0].name == "Steinbock Sloop" - # Search by ingredient - ingredient_result = database.recipes.page_all(pagination_query, search="moss").items - assert len(ingredient_result) == 1 - assert ingredient_result[0].name == "Fiddlehead Fern Stir Fry" - # Make sure title matches are ordered in front - ordered_result = database.recipes.page_all(pagination_query, search="animal sloop").items - 
assert len(ordered_result) == 2 - assert ordered_result[0].name == "Animal Sloop" - assert ordered_result[1].name == "Steinbock Sloop" - - # Test literal search - literal_result = database.recipes.page_all(pagination_query, search='"Animal Sloop"').items - assert len(literal_result) == 1 - assert literal_result[0].name == "Animal Sloop" - - # Test special character removal from non-literal searches - character_result = database.recipes.page_all(pagination_query, search="animal-sloop").items - assert len(character_result) == 2 - assert character_result[0].name == "Animal Sloop" - assert character_result[1].name == "Steinbock Sloop" - - # Test string normalization - normalized_result = database.recipes.page_all(pagination_query, search="ratat").items - print([r.name for r in normalized_result]) - assert len(normalized_result) == 1 - assert normalized_result[0].name == "Rátàtôuile" - - # Test token separation - token_result = database.recipes.page_all(pagination_query, search="delicious horns").items - assert len(token_result) == 1 - assert token_result[0].name == "Steinbock Sloop" - - # Test fuzzy search - if database.session.get_bind().name == "postgresql": - fuzzy_result = database.recipes.page_all(pagination_query, search="Steinbuck").items - assert len(fuzzy_result) == 1 - assert fuzzy_result[0].name == "Steinbock Sloop" - - # Test random ordering with search - pagination_query = PaginationQuery( +def test_random_order_recipe_search( + database: AllRepositories, + search_recipes: list[Recipe], # required so database is populated + unique_local_group_id: str, +): + repo = database.recipes.by_group(unique_local_group_id) # type: ignore + pagination = PaginationQuery( page=1, per_page=-1, order_by="random", @@ -620,7 +654,7 @@ def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser): order_direction=OrderDirection.asc, ) random_ordered = [] - for i in range(5): - pagination_query.pagination_seed = str(datetime.now()) - random_ordered.append(database.recipes.page_all(pagination_query, search="soup").items) + for _ in range(5): + pagination.pagination_seed = str(datetime.now()) + random_ordered.append(repo.page_all(pagination, search="soup").items) assert not all(i == random_ordered[0] for i in random_ordered) diff --git a/tests/unit_tests/repository_tests/test_search.py b/tests/unit_tests/repository_tests/test_search.py new file mode 100644 index 000000000..f79398f8a --- /dev/null +++ b/tests/unit_tests/repository_tests/test_search.py @@ -0,0 +1,135 @@ +from datetime import datetime + +import pytest + +from mealie.repos.repository_factory import AllRepositories +from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit +from mealie.schema.response.pagination import OrderDirection, PaginationQuery +from mealie.schema.user.user import GroupBase +from tests.utils.factories import random_int, random_string + + +@pytest.fixture() +def unique_local_group_id(database: AllRepositories) -> str: + return str(database.groups.create(GroupBase(name=random_string())).id) + + +@pytest.fixture() +def search_units(database: AllRepositories, unique_local_group_id: str) -> list[IngredientUnit]: + units = [ + SaveIngredientUnit( + group_id=unique_local_group_id, + name="Tea Spoon", + abbreviation="tsp", + ), + SaveIngredientUnit( + group_id=unique_local_group_id, + name="Table Spoon", + description="unique description", + abbreviation="tbsp", + ), + SaveIngredientUnit( + group_id=unique_local_group_id, + name="Cup", + description="A bucket that's full", + ), + 
SaveIngredientUnit( + group_id=unique_local_group_id, + name="Píñch", + ), + SaveIngredientUnit( + group_id=unique_local_group_id, + name="Unit with a very cool name", + ), + SaveIngredientUnit( + group_id=unique_local_group_id, + name="Unit with a pretty cool name", + ), + ] + + # Add a bunch of units for stable randomization + units.extend( + [ + SaveIngredientUnit(group_id=unique_local_group_id, name=f"{random_string()} unit") + for _ in range(random_int(12, 20)) + ] + ) + + return database.ingredient_units.create_many(units) + + +@pytest.mark.parametrize( + "search, expected_names", + [ + (random_string(), []), + ("Cup", ["Cup"]), + ("tbsp", ["Table Spoon"]), + ("unique description", ["Table Spoon"]), + ("very cool name", ["Unit with a very cool name", "Unit with a pretty cool name"]), + ('"Tea Spoon"', ["Tea Spoon"]), + ("full bucket", ["Cup"]), + ], + ids=[ + "no_match", + "search_by_name", + "search_by_unit", + "search_by_description", + "match_order", + "literal_search", + "token_separation", + ], +) +def test_basic_search( + search: str, + expected_names: list[str], + database: AllRepositories, + search_units: list[IngredientUnit], # required so database is populated + unique_local_group_id: str, +): + repo = database.ingredient_units.by_group(unique_local_group_id) + pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc) + results = repo.page_all(pagination, search=search).items + + if len(expected_names) == 0: + assert len(results) == 0 + else: + # if more results are returned, that's acceptable, as long as they are ranked correctly + assert len(results) >= len(expected_names) + for unit, name in zip(results, expected_names, strict=False): + assert unit.name == name + + +def test_fuzzy_search( + database: AllRepositories, + search_units: list[IngredientUnit], # required so database is populated + unique_local_group_id: str, +): + # this only works on postgres + if database.session.get_bind().name != "postgresql": + return + + repo = database.ingredient_units.by_group(unique_local_group_id) + pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc) + results = repo.page_all(pagination, search="unique decsription").items + + assert results and results[0].name == "Table Spoon" + + +def test_random_order_search( + database: AllRepositories, + search_units: list[IngredientUnit], # required so database is populated + unique_local_group_id: str, +): + repo = database.ingredient_units.by_group(unique_local_group_id) + pagination = PaginationQuery( + page=1, + per_page=-1, + order_by="random", + pagination_seed=str(datetime.now()), + order_direction=OrderDirection.asc, + ) + random_ordered = [] + for _ in range(5): + pagination.pagination_seed = str(datetime.now()) + random_ordered.append(repo.page_all(pagination, search="unit").items) + assert not all(i == random_ordered[0] for i in random_ordered)
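Note (not part of the diff above): a minimal sketch of how the new search plumbing fits together. It only relies on names introduced in this changeset (_searchable_properties, MealieModel.filter_search_query, SearchFilter, page_all(..., search=...)); the ExampleReadSchema class and the search_units helper are hypothetical, and the database / group_id arguments stand in for an AllRepositories instance and a group id, as used in the tests.

from typing import ClassVar

from mealie.schema._mealie import MealieModel
from mealie.schema.response.pagination import OrderDirection, PaginationQuery


class ExampleReadSchema(MealieModel):
    """Hypothetical read schema: declaring _searchable_properties opts it into the search API."""

    name: str
    description: str | None

    # the first property also drives ordering in MealieModel.filter_search_query
    _searchable_properties: ClassVar[list[str]] = ["name", "description"]


def search_units(database, group_id: str):
    # page_all() wraps the raw string in a SearchFilter, which chooses fuzzy matching
    # (postgres, unquoted search) or tokenized LIKE matching; quoted phrases are kept
    # together as literal tokens, so this looks for the phrase "tea spoon" plus "tbsp".
    pagination = PaginationQuery(page=1, per_page=20, order_by="created_at", order_direction=OrderDirection.asc)
    return database.ingredient_units.by_group(group_id).page_all(pagination, search='"tea spoon" tbsp').items

# Controllers only forward an optional search query parameter into repo.page_all(),
# as each get_all() endpoint above now does (e.g. adding ?search=tbsp to the units list endpoint).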