From 9a469fe4fde4d9547d955380075bc6e5e5d0a644 Mon Sep 17 00:00:00 2001 From: Michael Genson <71845777+michael-genson@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:34:32 -0600 Subject: [PATCH] fix: Filter out null chars from OpenAI response (#5187) --- mealie/schema/openai/_base.py | 29 +++++++++--- tests/unit_tests/test_ingredient_parser.py | 52 ++++++++++++++++++++++ 2 files changed, 75 insertions(+), 6 deletions(-) diff --git a/mealie/schema/openai/_base.py b/mealie/schema/openai/_base.py index bcf749665..cc565adcc 100644 --- a/mealie/schema/openai/_base.py +++ b/mealie/schema/openai/_base.py @@ -1,7 +1,12 @@ +import re +from typing import Self + from pydantic import BaseModel from mealie.core.root_logger import get_logger +RE_NULLS = re.compile(r"[\x00\u0000]|\\u0000") + logger = get_logger() @@ -14,14 +19,26 @@ class OpenAIBase(BaseModel): __doc__ = "" # we don't want to include the docstring in the JSON schema @classmethod - def parse_openai_response(cls, response: str | None): - """ - This method should be implemented in the child class. It should - parse the JSON response from OpenAI and return a dictionary. - """ + def _preprocess_response(cls, response: str | None) -> str: + if not response: + return "" + response = re.sub(RE_NULLS, "", response) + return response + + @classmethod + def _process_response(cls, response: str) -> Self: try: - return cls.model_validate_json(response or "") + return cls.model_validate_json(response) except Exception: logger.debug(f"Failed to parse OpenAI response as {cls}. Response: {response}") raise + + @classmethod + def parse_openai_response(cls, response: str | None) -> Self: + """ + Parse the OpenAI response into a class instance. + """ + + response = cls._preprocess_response(response) + return cls._process_response(response) diff --git a/tests/unit_tests/test_ingredient_parser.py b/tests/unit_tests/test_ingredient_parser.py index 5d3d930ae..e4a27429b 100644 --- a/tests/unit_tests/test_ingredient_parser.py +++ b/tests/unit_tests/test_ingredient_parser.py @@ -1,6 +1,7 @@ import asyncio import json from dataclasses import dataclass +from typing import cast import pytest from pydantic import UUID4 @@ -10,6 +11,7 @@ from mealie.db.db_setup import session_context from mealie.repos.all_repositories import get_repositories from mealie.repos.repository_factory import AllRepositories from mealie.schema.openai.recipe_ingredient import OpenAIIngredient, OpenAIIngredients +from mealie.schema.recipe.recipe import Recipe from mealie.schema.recipe.recipe_ingredient import ( CreateIngredientFood, CreateIngredientFoodAlias, @@ -26,6 +28,7 @@ from mealie.schema.user.user import GroupBase from mealie.services.openai import OpenAIService from mealie.services.parser_services import RegisteredParser, get_parser from tests.utils.factories import random_int, random_string +from tests.utils.fixture_schemas import TestUser @dataclass @@ -448,3 +451,52 @@ def test_openai_parser( assert len(parsed) == ingredient_count for input, output in zip(inputs, parsed, strict=True): assert output.input == input + + +def test_openai_parser_sanitize_output( + unique_local_group_id: UUID4, + unique_user: TestUser, + parsed_ingredient_data: tuple[list[IngredientFood], list[IngredientUnit]], # required so database is populated + monkeypatch: pytest.MonkeyPatch, +): + async def mock_get_response(self, prompt: str, message: str, *args, **kwargs) -> str | None: + data = OpenAIIngredients( + ingredients=[ + OpenAIIngredient( + input="there is a null character here: \x00", + confidence=1, + quantity=random_int(0, 10), + unit="", + food="there is a null character here: \x00", + note="", + ) + ] + ) + return data.model_dump_json() + + monkeypatch.setattr(OpenAIService, "get_response", mock_get_response) + + with session_context() as session: + loop = asyncio.get_event_loop() + parser = get_parser(RegisteredParser.openai, unique_local_group_id, session) + + parsed = loop.run_until_complete(parser.parse([""])) + assert len(parsed) == 1 + parsed_ing = cast(ParsedIngredient, parsed[0]) + assert parsed_ing.input + assert "\x00" not in parsed_ing.input + + # Make sure we can create a recipe with this ingredient + assert isinstance(parsed_ing.ingredient.food, CreateIngredientFood) + food = unique_user.repos.ingredient_foods.create( + parsed_ing.ingredient.food.cast(SaveIngredientFood, group_id=unique_user.group_id) + ) + parsed_ing.ingredient.food = food + unique_user.repos.recipes.create( + Recipe( + user_id=unique_user.user_id, + group_id=unique_user.group_id, + name=random_string(), + recipe_ingredient=[parsed_ing.ingredient], + ) + )