mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-08-02 20:15:24 +02:00
feat: Generalize Search to Other Models (#2472)
* generalized search logic to SearchFilter * added default search behavior for all models * fix for schema overrides * added search support to several models * fix for label search * tests and fixes * add config for normalizing characters * dramatically simplified search tests * bark bark * fix normalization bug * tweaked tests * maybe this time? --------- Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
parent
76ae0bafc7
commit
99372aa2b6
16 changed files with 521 additions and 250 deletions
|
@ -1,17 +1,108 @@
|
|||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.repos.repository_recipes import RepositoryRecipes
|
||||
from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood, RecipeStep
|
||||
from mealie.schema.recipe import RecipeIngredient, SaveIngredientFood
|
||||
from mealie.schema.recipe.recipe import Recipe, RecipeCategory, RecipeSummary
|
||||
from mealie.schema.recipe.recipe_category import CategoryOut, CategorySave, TagSave
|
||||
from mealie.schema.recipe.recipe_tool import RecipeToolSave
|
||||
from mealie.schema.response import OrderDirection, PaginationQuery
|
||||
from tests.utils.factories import random_string
|
||||
from mealie.schema.user.user import GroupBase
|
||||
from tests.utils.factories import random_email, random_string
|
||||
from tests.utils.fixture_schemas import TestUser
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def unique_local_group_id(database: AllRepositories) -> str:
|
||||
return str(database.groups.create(GroupBase(name=random_string())).id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def unique_local_user_id(database: AllRepositories, unique_local_group_id: str) -> str:
|
||||
return str(
|
||||
database.users.create(
|
||||
{
|
||||
"username": random_string(),
|
||||
"email": random_email(),
|
||||
"group_id": unique_local_group_id,
|
||||
"full_name": random_string(),
|
||||
"password": random_string(),
|
||||
"admin": False,
|
||||
}
|
||||
).id
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def search_recipes(database: AllRepositories, unique_local_group_id: str, unique_local_user_id: str) -> list[Recipe]:
|
||||
recipes = [
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name="Steinbock Sloop",
|
||||
description=f"My favorite horns are delicious",
|
||||
recipe_ingredient=[
|
||||
RecipeIngredient(note="alpine animal"),
|
||||
],
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name="Fiddlehead Fern Stir Fry",
|
||||
recipe_ingredient=[
|
||||
RecipeIngredient(note="moss"),
|
||||
],
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name="Animal Sloop",
|
||||
),
|
||||
# Test diacritics
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name="Rátàtôuile",
|
||||
),
|
||||
# Add a bunch of recipes for stable randomization
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_local_user_id,
|
||||
group_id=unique_local_group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
]
|
||||
|
||||
return database.recipes.create_many(recipes)
|
||||
|
||||
|
||||
def test_recipe_repo_get_by_categories_basic(database: AllRepositories, unique_user: TestUser):
|
||||
# Bootstrap the database with categories
|
||||
slug1, slug2, slug3 = (random_string(10) for _ in range(3))
|
||||
|
@ -112,7 +203,7 @@ def test_recipe_repo_get_by_categories_multi(database: AllRepositories, unique_u
|
|||
database.recipes.create(recipe)
|
||||
|
||||
# Get all recipes by both categories
|
||||
repo: RepositoryRecipes = database.recipes.by_group(unique_user.group_id) # type: ignore
|
||||
repo: RepositoryRecipes = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||
by_category = repo.get_by_categories(cast(list[RecipeCategory], created_categories))
|
||||
|
||||
assert len(by_category) == 10
|
||||
|
@ -490,129 +581,72 @@ def test_recipe_repo_pagination_by_foods(database: AllRepositories, unique_user:
|
|||
assert not all(i == random_ordered[0] for i in random_ordered)
|
||||
|
||||
|
||||
def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser):
|
||||
recipes = [
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name="Steinbock Sloop",
|
||||
description=f"My favorite horns are delicious",
|
||||
recipe_ingredient=[
|
||||
RecipeIngredient(note="alpine animal"),
|
||||
],
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name="Fiddlehead Fern Stir Fry",
|
||||
recipe_ingredient=[
|
||||
RecipeIngredient(note="moss"),
|
||||
],
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name="Animal Sloop",
|
||||
),
|
||||
# Test diacritics
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name="Rátàtôuile",
|
||||
),
|
||||
# Add a bunch of recipes for stable randomization
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
Recipe(
|
||||
user_id=unique_user.user_id,
|
||||
group_id=unique_user.group_id,
|
||||
name=f"{random_string(10)} soup",
|
||||
),
|
||||
]
|
||||
@pytest.mark.parametrize(
|
||||
"search, expected_names",
|
||||
[
|
||||
(random_string(), []),
|
||||
("Steinbock", ["Steinbock Sloop"]),
|
||||
("horns", ["Steinbock Sloop"]),
|
||||
("moss", ["Fiddlehead Fern Stir Fry"]),
|
||||
('"Animal Sloop"', ["Animal Sloop"]),
|
||||
("animal-sloop", ["Animal Sloop"]),
|
||||
("ratat", ["Rátàtôuile"]),
|
||||
("delicious horns", ["Steinbock Sloop"]),
|
||||
],
|
||||
ids=[
|
||||
"no_match",
|
||||
"search_by_title",
|
||||
"search_by_description",
|
||||
"search_by_ingredient",
|
||||
"literal_search",
|
||||
"special_character_removal",
|
||||
"normalization",
|
||||
"token_separation",
|
||||
],
|
||||
)
|
||||
def test_basic_recipe_search(
|
||||
search: str,
|
||||
expected_names: list[str],
|
||||
database: AllRepositories,
|
||||
search_recipes: list[Recipe], # required so database is populated
|
||||
unique_local_group_id: str,
|
||||
):
|
||||
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||
results = repo.page_all(pagination, search=search).items
|
||||
|
||||
for recipe in recipes:
|
||||
database.recipes.create(recipe)
|
||||
if len(expected_names) == 0:
|
||||
assert len(results) == 0
|
||||
else:
|
||||
# if more results are returned, that's acceptable, as long as they are ranked correctly
|
||||
assert len(results) >= len(expected_names)
|
||||
for recipe, name in zip(results, expected_names, strict=False):
|
||||
assert recipe.name == name
|
||||
|
||||
pagination_query = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||
|
||||
# No hits
|
||||
empty_result = database.recipes.page_all(pagination_query, search=random_string(10)).items
|
||||
assert len(empty_result) == 0
|
||||
def test_fuzzy_recipe_search(
|
||||
database: AllRepositories,
|
||||
search_recipes: list[Recipe], # required so database is populated
|
||||
unique_local_group_id: str,
|
||||
):
|
||||
# this only works on postgres
|
||||
if database.session.get_bind().name != "postgresql":
|
||||
return
|
||||
|
||||
# Search by title
|
||||
title_result = database.recipes.page_all(pagination_query, search="Steinbock").items
|
||||
assert len(title_result) == 1
|
||||
assert title_result[0].name == "Steinbock Sloop"
|
||||
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||
results = repo.page_all(pagination, search="Steinbuck").items
|
||||
|
||||
# Search by description
|
||||
description_result = database.recipes.page_all(pagination_query, search="horns").items
|
||||
assert len(description_result) == 1
|
||||
assert description_result[0].name == "Steinbock Sloop"
|
||||
assert results and results[0].name == "Steinbock Sloop"
|
||||
|
||||
# Search by ingredient
|
||||
ingredient_result = database.recipes.page_all(pagination_query, search="moss").items
|
||||
assert len(ingredient_result) == 1
|
||||
assert ingredient_result[0].name == "Fiddlehead Fern Stir Fry"
|
||||
|
||||
# Make sure title matches are ordered in front
|
||||
ordered_result = database.recipes.page_all(pagination_query, search="animal sloop").items
|
||||
assert len(ordered_result) == 2
|
||||
assert ordered_result[0].name == "Animal Sloop"
|
||||
assert ordered_result[1].name == "Steinbock Sloop"
|
||||
|
||||
# Test literal search
|
||||
literal_result = database.recipes.page_all(pagination_query, search='"Animal Sloop"').items
|
||||
assert len(literal_result) == 1
|
||||
assert literal_result[0].name == "Animal Sloop"
|
||||
|
||||
# Test special character removal from non-literal searches
|
||||
character_result = database.recipes.page_all(pagination_query, search="animal-sloop").items
|
||||
assert len(character_result) == 2
|
||||
assert character_result[0].name == "Animal Sloop"
|
||||
assert character_result[1].name == "Steinbock Sloop"
|
||||
|
||||
# Test string normalization
|
||||
normalized_result = database.recipes.page_all(pagination_query, search="ratat").items
|
||||
print([r.name for r in normalized_result])
|
||||
assert len(normalized_result) == 1
|
||||
assert normalized_result[0].name == "Rátàtôuile"
|
||||
|
||||
# Test token separation
|
||||
token_result = database.recipes.page_all(pagination_query, search="delicious horns").items
|
||||
assert len(token_result) == 1
|
||||
assert token_result[0].name == "Steinbock Sloop"
|
||||
|
||||
# Test fuzzy search
|
||||
if database.session.get_bind().name == "postgresql":
|
||||
fuzzy_result = database.recipes.page_all(pagination_query, search="Steinbuck").items
|
||||
assert len(fuzzy_result) == 1
|
||||
assert fuzzy_result[0].name == "Steinbock Sloop"
|
||||
|
||||
# Test random ordering with search
|
||||
pagination_query = PaginationQuery(
|
||||
def test_random_order_recipe_search(
|
||||
database: AllRepositories,
|
||||
search_recipes: list[Recipe], # required so database is populated
|
||||
unique_local_group_id: str,
|
||||
):
|
||||
repo = database.recipes.by_group(unique_local_group_id) # type: ignore
|
||||
pagination = PaginationQuery(
|
||||
page=1,
|
||||
per_page=-1,
|
||||
order_by="random",
|
||||
|
@ -620,7 +654,7 @@ def test_recipe_repo_search(database: AllRepositories, unique_user: TestUser):
|
|||
order_direction=OrderDirection.asc,
|
||||
)
|
||||
random_ordered = []
|
||||
for i in range(5):
|
||||
pagination_query.pagination_seed = str(datetime.now())
|
||||
random_ordered.append(database.recipes.page_all(pagination_query, search="soup").items)
|
||||
for _ in range(5):
|
||||
pagination.pagination_seed = str(datetime.now())
|
||||
random_ordered.append(repo.page_all(pagination, search="soup").items)
|
||||
assert not all(i == random_ordered[0] for i in random_ordered)
|
||||
|
|
135
tests/unit_tests/repository_tests/test_search.py
Normal file
135
tests/unit_tests/repository_tests/test_search.py
Normal file
|
@ -0,0 +1,135 @@
|
|||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
||||
from mealie.schema.response.pagination import OrderDirection, PaginationQuery
|
||||
from mealie.schema.user.user import GroupBase
|
||||
from tests.utils.factories import random_int, random_string
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def unique_local_group_id(database: AllRepositories) -> str:
|
||||
return str(database.groups.create(GroupBase(name=random_string())).id)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def search_units(database: AllRepositories, unique_local_group_id: str) -> list[IngredientUnit]:
|
||||
units = [
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_local_group_id,
|
||||
name="Tea Spoon",
|
||||
abbreviation="tsp",
|
||||
),
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_local_group_id,
|
||||
name="Table Spoon",
|
||||
description="unique description",
|
||||
abbreviation="tbsp",
|
||||
),
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_local_group_id,
|
||||
name="Cup",
|
||||
description="A bucket that's full",
|
||||
),
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_local_group_id,
|
||||
name="Píñch",
|
||||
),
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_local_group_id,
|
||||
name="Unit with a very cool name",
|
||||
),
|
||||
SaveIngredientUnit(
|
||||
group_id=unique_local_group_id,
|
||||
name="Unit with a pretty cool name",
|
||||
),
|
||||
]
|
||||
|
||||
# Add a bunch of units for stable randomization
|
||||
units.extend(
|
||||
[
|
||||
SaveIngredientUnit(group_id=unique_local_group_id, name=f"{random_string()} unit")
|
||||
for _ in range(random_int(12, 20))
|
||||
]
|
||||
)
|
||||
|
||||
return database.ingredient_units.create_many(units)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"search, expected_names",
|
||||
[
|
||||
(random_string(), []),
|
||||
("Cup", ["Cup"]),
|
||||
("tbsp", ["Table Spoon"]),
|
||||
("unique description", ["Table Spoon"]),
|
||||
("very cool name", ["Unit with a very cool name", "Unit with a pretty cool name"]),
|
||||
('"Tea Spoon"', ["Tea Spoon"]),
|
||||
("full bucket", ["Cup"]),
|
||||
],
|
||||
ids=[
|
||||
"no_match",
|
||||
"search_by_name",
|
||||
"search_by_unit",
|
||||
"search_by_description",
|
||||
"match_order",
|
||||
"literal_search",
|
||||
"token_separation",
|
||||
],
|
||||
)
|
||||
def test_basic_search(
|
||||
search: str,
|
||||
expected_names: list[str],
|
||||
database: AllRepositories,
|
||||
search_units: list[IngredientUnit], # required so database is populated
|
||||
unique_local_group_id: str,
|
||||
):
|
||||
repo = database.ingredient_units.by_group(unique_local_group_id)
|
||||
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||
results = repo.page_all(pagination, search=search).items
|
||||
|
||||
if len(expected_names) == 0:
|
||||
assert len(results) == 0
|
||||
else:
|
||||
# if more results are returned, that's acceptable, as long as they are ranked correctly
|
||||
assert len(results) >= len(expected_names)
|
||||
for unit, name in zip(results, expected_names, strict=False):
|
||||
assert unit.name == name
|
||||
|
||||
|
||||
def test_fuzzy_search(
|
||||
database: AllRepositories,
|
||||
search_units: list[IngredientUnit], # required so database is populated
|
||||
unique_local_group_id: str,
|
||||
):
|
||||
# this only works on postgres
|
||||
if database.session.get_bind().name != "postgresql":
|
||||
return
|
||||
|
||||
repo = database.ingredient_units.by_group(unique_local_group_id)
|
||||
pagination = PaginationQuery(page=1, per_page=-1, order_by="created_at", order_direction=OrderDirection.asc)
|
||||
results = repo.page_all(pagination, search="unique decsription").items
|
||||
|
||||
assert results and results[0].name == "Table Spoon"
|
||||
|
||||
|
||||
def test_random_order_search(
|
||||
database: AllRepositories,
|
||||
search_units: list[IngredientUnit], # required so database is populated
|
||||
unique_local_group_id: str,
|
||||
):
|
||||
repo = database.ingredient_units.by_group(unique_local_group_id)
|
||||
pagination = PaginationQuery(
|
||||
page=1,
|
||||
per_page=-1,
|
||||
order_by="random",
|
||||
pagination_seed=str(datetime.now()),
|
||||
order_direction=OrderDirection.asc,
|
||||
)
|
||||
random_ordered = []
|
||||
for _ in range(5):
|
||||
pagination.pagination_seed = str(datetime.now())
|
||||
random_ordered.append(repo.page_all(pagination, search="unit").items)
|
||||
assert not all(i == random_ordered[0] for i in random_ordered)
|
Loading…
Add table
Add a link
Reference in a new issue