diff --git a/mealie/schema/response/query_filter.py b/mealie/schema/response/query_filter.py index f525356f6..a865ac112 100644 --- a/mealie/schema/response/query_filter.py +++ b/mealie/schema/response/query_filter.py @@ -335,13 +335,25 @@ class QueryFilterBuilder: return current_model, model_attr, query - @staticmethod - def _get_filter_element( - component: QueryFilterBuilderComponent, model, model_attr, model_attr_type - ) -> sa.ColumnElement: + @classmethod + def _transform_model_attr(cls, model_attr: InstrumentedAttribute, model_attr_type: Any) -> InstrumentedAttribute: if isinstance(model_attr_type, sqltypes.String): model_attr = sa.func.lower(model_attr) + return model_attr + + @classmethod + def _get_filter_element( + cls, + query: sa.Select, + component: QueryFilterBuilderComponent, + model: type[Model], + model_attr: InstrumentedAttribute, + model_attr_type: Any, + ) -> sa.ColumnElement: + original_model_attr = model_attr + model_attr = cls._transform_model_attr(model_attr, model_attr_type) + # Keywords if component.relationship is RelationalKeyword.IS: element = model_attr.is_(component.validate(model_attr_type)) @@ -350,7 +362,13 @@ class QueryFilterBuilder: elif component.relationship is RelationalKeyword.IN: element = model_attr.in_(component.validate(model_attr_type)) elif component.relationship is RelationalKeyword.NOT_IN: - element = model_attr.not_in(component.validate(model_attr_type)) + vals = component.validate(model_attr_type) + if original_model_attr.parent.entity != model: + subq = query.with_only_columns(model.id).where(model_attr.in_(vals)) + element = sa.not_(model.id.in_(subq)) + else: + element = sa.not_(model_attr.in_(vals)) + elif component.relationship is RelationalKeyword.CONTAINS_ALL: primary_model_attr: InstrumentedAttribute = getattr(model, component.attribute_name.split(".")[0]) element = sa.and_() @@ -428,7 +446,7 @@ class QueryFilterBuilder: if (column_alias := column_aliases.get(base_attribute_name)) is not None: model_attr = column_alias - element = self._get_filter_element(component, model, model_attr, model_attr.type) + element = self._get_filter_element(query, component, model, model_attr, model_attr.type) partial_group.append(element) # combine the completed groups into one filter diff --git a/tests/unit_tests/repository_tests/test_pagination.py b/tests/unit_tests/repository_tests/test_pagination.py index 606321789..4d82af9e1 100644 --- a/tests/unit_tests/repository_tests/test_pagination.py +++ b/tests/unit_tests/repository_tests/test_pagination.py @@ -312,6 +312,18 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit, assert unit_2.id in result_ids assert unit_3.id not in result_ids + query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]') + unit_results = units_repo.page_all(query).items + + result_ids = {unit.id for unit in unit_results} + assert unit_1.id not in result_ids + assert unit_2.id not in result_ids + assert unit_3.id in result_ids + + +def test_pagination_filter_not_in(query_units: tuple[RepositoryUnit, IngredientUnit, IngredientUnit, IngredientUnit]): + units_repo, unit_1, unit_2, unit_3 = query_units + query = PaginationQuery(page=1, per_page=-1, query_filter=f"name NOT IN [{unit_1.name}, {unit_2.name}]") unit_results = units_repo.page_all(query).items @@ -320,13 +332,73 @@ def test_pagination_filter_in(query_units: tuple[RepositoryUnit, IngredientUnit, assert unit_2.id not in result_ids assert unit_3.id in result_ids - query = PaginationQuery(page=1, per_page=-1, query_filter=f'name IN ["{unit_3.name}"]') - unit_results = units_repo.page_all(query).items - result_ids = {unit.id for unit in unit_results} - assert unit_1.id not in result_ids - assert unit_2.id not in result_ids - assert unit_3.id in result_ids +def test_pagination_filter_in_m2m(unique_user: TestUser): + db = unique_user.repos + unique_category_1, unique_category_2, shared_category = ( + db.categories.create(CategorySave(group_id=unique_user.group_id, name=random_string(10))) for _ in range(3) + ) + recipe_1, recipe_2 = ( + db.recipes.create(Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string())) + for _ in range(2) + ) + + recipe_1.recipe_category = [unique_category_1, shared_category] + recipe_2.recipe_category = [unique_category_2, shared_category] + db.recipes.update(recipe_1.slug, recipe_1) + db.recipes.update(recipe_2.slug, recipe_2) + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"recipeCategory.name IN [{shared_category.name}]") + recipe_results = db.recipes.page_all(query).items + assert len(recipe_results) == 2 + assert {recipe.id for recipe in recipe_results} == {recipe_1.id, recipe_2.id} + + +def test_pagination_filter_not_in_m2m(unique_user: TestUser): + db = unique_user.repos + unique_category_1, unique_category_2, shared_category = ( + db.categories.create(CategorySave(group_id=unique_user.group_id, name=random_string(10))) for _ in range(3) + ) + recipe_1, recipe_2 = ( + db.recipes.create(Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string())) + for _ in range(2) + ) + + recipe_1.recipe_category = [unique_category_1, shared_category] + recipe_2.recipe_category = [unique_category_2, shared_category] + db.recipes.update(recipe_1.slug, recipe_1) + db.recipes.update(recipe_2.slug, recipe_2) + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"recipeCategory.name NOT IN [{unique_category_1.name}]") + recipe_results = db.recipes.page_all(query).items + recipe_results_ids = {recipe.id for recipe in recipe_results} + assert recipe_1.id not in recipe_results_ids + assert recipe_2.id in recipe_results_ids + + +def test_pagination_filter_not_in_includes_null(unique_user: TestUser): + db = unique_user.repos + unique_category_1, unique_category_2, shared_category = ( + db.categories.create(CategorySave(group_id=unique_user.group_id, name=random_string(10))) for _ in range(3) + ) + recipe_1, recipe_2, recipe_3 = ( + db.recipes.create(Recipe(user_id=unique_user.user_id, group_id=unique_user.group_id, name=random_string())) + for _ in range(3) + ) + + recipe_1.recipe_category = [unique_category_1, shared_category] + recipe_2.recipe_category = [unique_category_2, shared_category] + db.recipes.update(recipe_1.slug, recipe_1) + db.recipes.update(recipe_2.slug, recipe_2) + + query = PaginationQuery(page=1, per_page=-1, query_filter=f"recipeCategory.name NOT IN [{unique_category_1.name}]") + recipe_results = db.recipes.page_all(query).items + recipe_results_ids = {recipe.id for recipe in recipe_results} + assert recipe_1.id not in recipe_results_ids + assert recipe_2.id in recipe_results_ids + + # this recipe has no categories, and therefore should be included in the results + assert recipe_3.id in recipe_results_ids def test_pagination_filter_in_advanced(unique_user: TestUser):