1
0
Fork 0
mirror of https://github.com/mealie-recipes/mealie.git synced 2025-07-25 08:09:41 +02:00

feat: Advanced Query Filter Record Ordering (#2530)

* added support for multiple order_by strs

* refactored qf to expose nested attr logic

* added nested attr support to order_by

* added tests

* changed unique user to be function-level

* updated docs

* added support for null handling

* updated docs

* undid fixture changes

* fix leaky tests

* added advanced shopping list item test

---------

Co-authored-by: Hayden <64056131+hay-kot@users.noreply.github.com>
This commit is contained in:
Michael Genson 2023-09-14 09:09:05 -05:00 committed by GitHub
parent 2c5e5a8421
commit aec4cb4f31
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 483 additions and 66 deletions

View file

@ -1,3 +1,4 @@
import random
import time
from collections import defaultdict
from datetime import date, datetime, timedelta
@ -7,21 +8,54 @@ from urllib.parse import parse_qsl, urlsplit
import pytest
from fastapi.testclient import TestClient
from humps import camelize
from pydantic import UUID4
from mealie.repos.repository_factory import AllRepositories
from mealie.repos.repository_units import RepositoryUnit
from mealie.schema.group.group_shopping_list import (
ShoppingListItemCreate,
ShoppingListMultiPurposeLabelCreate,
ShoppingListMultiPurposeLabelOut,
ShoppingListSave,
)
from mealie.schema.labels.multi_purpose_label import MultiPurposeLabelSave
from mealie.schema.meal_plan.new_meal import CreatePlanEntry
from mealie.schema.recipe import Recipe
from mealie.schema.recipe.recipe_category import CategorySave, TagSave
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientUnit
from mealie.schema.recipe.recipe_ingredient import IngredientUnit, SaveIngredientFood, SaveIngredientUnit
from mealie.schema.recipe.recipe_tool import RecipeToolSave
from mealie.schema.response.pagination import PaginationQuery
from mealie.schema.response.pagination import OrderByNullPosition, OrderDirection, PaginationQuery
from mealie.services.seeder.seeder_service import SeederService
from tests.utils import api_routes
from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser
class Reversor:
"""
Enables reversed sorting
https://stackoverflow.com/a/56842689
"""
def __init__(self, obj):
self.obj = obj
def __eq__(self, other):
return other.obj == self.obj
def __lt__(self, other):
return other.obj < self.obj
def get_label_position_from_label_id(label_id: UUID4, label_settings: list[ShoppingListMultiPurposeLabelOut]) -> int:
for label_setting in label_settings:
if label_setting.label_id == label_id:
return label_setting.position
raise Exception("Something went wrong when parsing label settings")
def test_repository_pagination(database: AllRepositories, unique_user: TestUser):
group = database.groups.get_one(unique_user.group_id)
assert group
@ -153,14 +187,6 @@ def query_units(database: AllRepositories, unique_user: TestUser):
unit_ids = [unit.id for unit in [unit_1, unit_2, unit_3]]
units_repo = database.ingredient_units.by_group(unique_user.group_id) # type: ignore
# make sure we can get all of our test units
query = PaginationQuery(page=1, per_page=-1)
all_units = units_repo.page_all(query).items
assert len(all_units) == 3
for unit in all_units:
assert unit.id in unit_ids
yield units_repo, unit_1, unit_2, unit_3
for unit_id in unit_ids:
@ -233,7 +259,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
@ -242,7 +267,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]")
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
@ -251,7 +275,6 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit,
query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]')
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
result_ids = {unit.id for unit in unit_results}
assert unit_1.id not in result_ids
assert unit_2.id not in result_ids
@ -521,6 +544,282 @@ def test_pagination_filter_datetimes(
assert len(unit_ids) == 0
@pytest.mark.parametrize("order_direction", [OrderDirection.asc, OrderDirection.desc], ids=["ascending", "descending"])
def test_pagination_order_by_multiple(
database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
):
current_time = datetime.now()
alphabet = ["a", "b", "c", "d", "e"]
abbreviations = alphabet.copy()
descriptions = alphabet.copy()
random.shuffle(abbreviations)
random.shuffle(descriptions)
assert abbreviations != descriptions
units_to_create: list[SaveIngredientUnit] = []
for abbreviation in abbreviations:
for description in descriptions:
units_to_create.append(
SaveIngredientUnit(
group_id=unique_user.group_id,
name=random_string(),
abbreviation=abbreviation,
description=description,
)
)
sorted_units = database.ingredient_units.create_many(units_to_create)
sorted_units.sort(key=lambda x: (x.abbreviation, x.description), reverse=order_direction is OrderDirection.desc)
query = database.ingredient_units.page_all(
PaginationQuery(
page=1,
per_page=-1,
order_by="abbreviation, description",
order_direction=order_direction,
query_filter=f'created_at >= "{current_time.isoformat()}"',
)
)
assert query.items == sorted_units
@pytest.mark.parametrize(
"order_by_str, order_direction",
[
("abbreviation:asc, description:desc", OrderDirection.asc),
("abbreviation:asc, description:desc", OrderDirection.desc),
("abbreviation, description:desc", OrderDirection.asc),
("abbreviation:asc, description", OrderDirection.desc),
],
ids=[
"order_by_asc_explicit_order_bys",
"order_by_desc_explicit_order_bys",
"order_by_asc_inferred_order_by",
"order_by_desc_inferred_order_by",
],
)
def test_pagination_order_by_multiple_directions(
database: AllRepositories, unique_user: TestUser, order_by_str: str, order_direction: OrderDirection
):
current_time = datetime.now()
alphabet = ["a", "b", "c", "d", "e"]
abbreviations = alphabet.copy()
descriptions = alphabet.copy()
random.shuffle(abbreviations)
random.shuffle(descriptions)
assert abbreviations != descriptions
units_to_create: list[SaveIngredientUnit] = []
for abbreviation in abbreviations:
for description in descriptions:
units_to_create.append(
SaveIngredientUnit(
group_id=unique_user.group_id,
name=random_string(),
abbreviation=abbreviation,
description=description,
)
)
sorted_units = database.ingredient_units.create_many(units_to_create)
# sort by abbreviation ascending, description descending
sorted_units.sort(key=lambda x: (x.abbreviation, Reversor(x.description)))
query = database.ingredient_units.page_all(
PaginationQuery(
page=1,
per_page=-1,
order_by=order_by_str,
order_direction=order_direction,
query_filter=f'created_at >= "{current_time.isoformat()}"',
)
)
assert query.items == sorted_units
@pytest.mark.parametrize(
"order_direction",
[OrderDirection.asc, OrderDirection.desc],
ids=["order_ascending", "order_descending"],
)
def test_pagination_order_by_nested_model(
database: AllRepositories, unique_user: TestUser, order_direction: OrderDirection
):
current_time = datetime.now()
alphabet = ["a", "b", "c", "d", "e"]
labels = database.group_multi_purpose_labels.create_many(
[MultiPurposeLabelSave(group_id=unique_user.group_id, name=letter) for letter in alphabet]
)
random.shuffle(labels)
sorted_foods = database.ingredient_foods.create_many(
[SaveIngredientFood(group_id=unique_user.group_id, name=random_string(), label_id=label.id) for label in labels]
)
sorted_foods.sort(key=lambda x: x.label.name, reverse=order_direction is OrderDirection.desc) # type: ignore
query = database.ingredient_foods.page_all(
PaginationQuery(
page=1,
per_page=-1,
order_by="label.name",
order_direction=order_direction,
query_filter=f'created_at >= "{current_time.isoformat()}"',
)
)
assert query.items == sorted_foods
def test_pagination_order_by_doesnt_filter(database: AllRepositories, unique_user: TestUser):
current_time = datetime.now()
label = database.group_multi_purpose_labels.create(
MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
)
food_with_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
)
food_without_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
)
query = database.ingredient_foods.by_group(unique_user.group_id).page_all(
PaginationQuery(per_page=-1, query_filter=f"created_at>{current_time.isoformat()}", order_by="label.name")
)
assert len(query.items) == 2
found_ids = {item.id for item in query.items}
assert food_with_label.id in found_ids
assert food_without_label.id in found_ids
@pytest.mark.parametrize(
"null_position, order_direction",
[
(OrderByNullPosition.first, OrderDirection.asc),
(OrderByNullPosition.last, OrderDirection.asc),
(OrderByNullPosition.first, OrderDirection.asc),
(OrderByNullPosition.last, OrderDirection.asc),
],
ids=[
"order_by_nulls_first_order_direction_asc",
"order_by_nulls_last_order_direction_asc",
"order_by_nulls_first_order_direction_desc",
"order_by_nulls_last_order_direction_desc",
],
)
def test_pagination_order_by_nulls(
database: AllRepositories,
unique_user: TestUser,
null_position: OrderByNullPosition,
order_direction: OrderDirection,
):
current_time = datetime.now()
label = database.group_multi_purpose_labels.create(
MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id)
)
food_with_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), label_id=label.id, group_id=unique_user.group_id)
)
food_without_label = database.ingredient_foods.create(
SaveIngredientFood(name=random_string(), group_id=unique_user.group_id)
)
query = database.ingredient_foods.page_all(
PaginationQuery(
per_page=-1,
query_filter=f"created_at >= {current_time.isoformat()}",
order_by="label.name",
order_by_null_position=null_position,
order_direction=order_direction,
)
)
assert len(query.items) == 2
if null_position is OrderByNullPosition.first:
assert query.items[0] == food_without_label
assert query.items[1] == food_with_label
else:
assert query.items[0] == food_with_label
assert query.items[1] == food_without_label
def test_pagination_shopping_list_items_with_labels(database: AllRepositories, unique_user: TestUser):
# create a shopping list and populate it with some items with labels, and some without labels
shopping_list = database.group_shopping_lists.create(
ShoppingListSave(name=random_string(), group_id=unique_user.group_id)
)
labels = database.group_multi_purpose_labels.create_many(
[MultiPurposeLabelSave(name=random_string(), group_id=unique_user.group_id) for _ in range(8)]
)
random.shuffle(labels)
label_settings = database.shopping_list_multi_purpose_labels.create_many(
[
ShoppingListMultiPurposeLabelCreate(shopping_list_id=shopping_list.id, label_id=label.id, position=i)
for i, label in enumerate(labels)
]
)
random.shuffle(label_settings)
with_labels_positions = list(range(0, random_int(20, 25)))
random.shuffle(with_labels_positions)
items_with_labels = database.group_shopping_list_item.create_many(
[
ShoppingListItemCreate(
note=random_string(),
shopping_list_id=shopping_list.id,
label_id=random.choice(labels).id,
position=position,
)
for position in with_labels_positions
]
)
# sort by item label position ascending, then item position ascending
items_with_labels.sort(
key=lambda x: (
get_label_position_from_label_id(x.label.id, label_settings), # type: ignore[union-attr]
x.position,
)
)
without_labels_positions = list(range(len(with_labels_positions), random_int(5, 10)))
random.shuffle(without_labels_positions)
items_without_labels = database.group_shopping_list_item.create_many(
[
ShoppingListItemCreate(
note=random_string(),
shopping_list_id=shopping_list.id,
label_id=random.choice(labels).id,
position=position,
)
for position in without_labels_positions
]
)
items_without_labels.sort(key=lambda x: x.position)
# verify they're in order
query = database.group_shopping_list_item.page_all(
PaginationQuery(
per_page=-1,
order_by="label.shopping_lists_label_settings.position, position",
order_direction=OrderDirection.asc,
order_by_null_position=OrderByNullPosition.first,
),
)
assert query.items == items_without_labels + items_with_labels
def test_pagination_filter_dates(api_client: TestClient, unique_user: TestUser):
yesterday = date.today() - timedelta(days=1)
today = date.today()
@ -616,7 +915,11 @@ def test_pagination_filter_booleans(query_units: tuple[RepositoryUnit, Ingredien
units_repo = query_units[0]
unit_1 = query_units[1]
query = PaginationQuery(page=1, per_page=-1, query_filter="useAbbreviation=true")
query = PaginationQuery(
page=1,
per_page=-1,
query_filter=f"useAbbreviation=true AND id IN [{', '.join([str(unit.id) for unit in query_units[1:]])}]",
)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 1
assert unit_results[0].id == unit_1.id
@ -630,7 +933,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids
@ -640,7 +942,6 @@ def test_pagination_filter_advanced(query_units: tuple[RepositoryUnit, Ingredien
query = PaginationQuery(page=1, per_page=-1, query_filter=qf)
unit_results = units_repo.page_all(query).items
assert len(unit_results) == 2
result_ids = {unit.id for unit in unit_results}
assert unit_1.id in result_ids
assert unit_2.id in result_ids