mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-08-03 12:35:22 +02:00
feat: advanced filtering API (#1468)
* created query filter classes * extended pagination to include query filtering * added filtering tests * type improvements * move type help to dev depedency * minor type and perf fixes * breakup test cases Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
parent
c64da1fdb7
commit
7f50071312
8 changed files with 480 additions and 353 deletions
|
@ -1,9 +1,14 @@
|
|||
import time
|
||||
from random import randint
|
||||
from urllib.parse import parse_qsl, urlsplit
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from humps import camelize
|
||||
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.repos.repository_units import RepositoryUnit
|
||||
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
|
||||
from mealie.schema.response.pagination import PaginationQuery
|
||||
from mealie.services.seeder.seeder_service import SeederService
|
||||
from tests.utils.fixture_schemas import TestUser
|
||||
|
@ -82,20 +87,14 @@ def test_pagination_guides(database: AllRepositories, unique_user: TestUser):
|
|||
"/foods" # this doesn't actually have to be accurate, it's just a placeholder to test for query params
|
||||
)
|
||||
|
||||
query = PaginationQuery(
|
||||
page=1,
|
||||
per_page=1,
|
||||
)
|
||||
query = PaginationQuery(page=1, per_page=1)
|
||||
|
||||
first_page_of_results = foods_repo.page_all(query)
|
||||
first_page_of_results.set_pagination_guides(foods_route, query.dict())
|
||||
assert first_page_of_results.next is not None
|
||||
assert first_page_of_results.previous is None
|
||||
|
||||
query = PaginationQuery(
|
||||
page=-1,
|
||||
per_page=1,
|
||||
)
|
||||
query = PaginationQuery(page=-1, per_page=1)
|
||||
|
||||
last_page_of_results = foods_repo.page_all(query)
|
||||
last_page_of_results.set_pagination_guides(foods_route, query.dict())
|
||||
|
@ -103,10 +102,7 @@ def test_pagination_guides(database: AllRepositories, unique_user: TestUser):
|
|||
assert last_page_of_results.previous is not None
|
||||
|
||||
random_page = randint(2, first_page_of_results.total_pages - 1)
|
||||
query = PaginationQuery(
|
||||
page=random_page,
|
||||
per_page=1,
|
||||
)
|
||||
query = PaginationQuery(page=random_page, per_page=1, filter_string="createdAt>2021-02-22")
|
||||
|
||||
random_page_of_results = foods_repo.page_all(query)
|
||||
random_page_of_results.set_pagination_guides(foods_route, query.dict())
|
||||
|
@ -121,3 +117,102 @@ def test_pagination_guides(database: AllRepositories, unique_user: TestUser):
|
|||
for source_param in source_params:
|
||||
assert source_param in next_params
|
||||
assert source_param in prev_params
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def query_units(database: AllRepositories, unique_user: TestUser):
|
||||
unit_1 = database.ingredient_units.create(
|
||||
SaveIngredientUnit(name="test unit 1", group_id=unique_user.group_id, use_abbreviation=True)
|
||||
)
|
||||
|
||||
# wait a moment so we can test datetime filters
|
||||
time.sleep(0.25)
|
||||
|
||||
unit_2 = database.ingredient_units.create(
|
||||
SaveIngredientUnit(name="test unit 2", group_id=unique_user.group_id, use_abbreviation=False)
|
||||
)
|
||||
|
||||
# wait a moment so we can test datetime filters
|
||||
time.sleep(0.25)
|
||||
|
||||
unit_3 = database.ingredient_units.create(
|
||||
SaveIngredientUnit(name="test unit 3", group_id=unique_user.group_id, use_abbreviation=False)
|
||||
)
|
||||
|
||||
unit_ids = [unit.id for unit in [unit_1, unit_2, unit_3]]
|
||||
units_repo = database.ingredient_units.by_group(unique_user.group_id) # type: ignore
|
||||
|
||||
# make sure we can get all of our test units
|
||||
query = PaginationQuery(page=1, per_page=-1)
|
||||
all_units = units_repo.page_all(query).items
|
||||
assert len(all_units) == 3
|
||||
|
||||
for unit in all_units:
|
||||
assert unit.id in unit_ids
|
||||
|
||||
yield units_repo, unit_1, unit_2, unit_3
|
||||
|
||||
for unit_id in unit_ids:
|
||||
units_repo.delete(unit_id)
|
||||
|
||||
|
||||
def test_pagination_filter_basic(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
|
||||
units_repo = query_units[0]
|
||||
unit_2 = query_units[2]
|
||||
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter='name="test unit 2"')
|
||||
unit_results = units_repo.page_all(query).items
|
||||
assert len(unit_results) == 1
|
||||
assert unit_results[0].id == unit_2.id
|
||||
|
||||
|
||||
def test_pagination_filter_datetimes(
|
||||
query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]
|
||||
):
|
||||
units_repo = query_units[0]
|
||||
unit_1 = query_units[1]
|
||||
unit_2 = query_units[2]
|
||||
|
||||
dt = unit_2.created_at.isoformat()
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=f'createdAt>="{dt}"')
|
||||
unit_results = units_repo.page_all(query).items
|
||||
assert len(unit_results) == 2
|
||||
assert unit_1.id not in [unit.id for unit in unit_results]
|
||||
|
||||
|
||||
def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
|
||||
units_repo = query_units[0]
|
||||
unit_1 = query_units[1]
|
||||
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter="useAbbreviation=true")
|
||||
unit_results = units_repo.page_all(query).items
|
||||
assert len(unit_results) == 1
|
||||
assert unit_results[0].id == unit_1.id
|
||||
|
||||
|
||||
def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]):
|
||||
units_repo = query_units[0]
|
||||
unit_3 = query_units[3]
|
||||
|
||||
dt = unit_3.created_at.isoformat()
|
||||
qf = f'name="test unit 1" OR (useAbbreviation=f AND (name="test unit 2" OR createdAt > "{dt}"))'
|
||||
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
|
||||
unit_results = units_repo.page_all(query).items
|
||||
assert len(unit_results) == 2
|
||||
assert unit_3.id not in [unit.id for unit in unit_results]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"qf",
|
||||
[
|
||||
pytest.param('(name="test name" AND useAbbreviation=f))', id="unbalanced parenthesis"),
|
||||
pytest.param('createdAt="this is not a valid datetime format"', id="invalid datetime format"),
|
||||
pytest.param('badAttribute="test value"', id="invalid attribute"),
|
||||
],
|
||||
)
|
||||
def test_malformed_query_filters(api_client: TestClient, unique_user: TestUser, qf: str):
|
||||
# verify that improper queries throw 400 errors
|
||||
route = "/api/units"
|
||||
|
||||
response = api_client.get(route, params={"queryFilter": qf}, headers=unique_user.token)
|
||||
assert response.status_code == 400
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue