1
0
Fork 0
mirror of https://github.com/mealie-recipes/mealie.git synced 2025-07-24 07:39:41 +02:00

fix: Filter out null chars from OpenAI response (#5187)

This commit is contained in:
Michael Genson 2025-03-07 10:34:32 -06:00 committed by GitHub
parent 98472ff471
commit 9a469fe4fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 75 additions and 6 deletions

View file

@ -1,7 +1,12 @@
import re
from typing import Self
from pydantic import BaseModel from pydantic import BaseModel
from mealie.core.root_logger import get_logger from mealie.core.root_logger import get_logger
RE_NULLS = re.compile(r"[\x00\u0000]|\\u0000")
logger = get_logger() logger = get_logger()
@ -14,14 +19,26 @@ class OpenAIBase(BaseModel):
__doc__ = "" # we don't want to include the docstring in the JSON schema __doc__ = "" # we don't want to include the docstring in the JSON schema
@classmethod @classmethod
def parse_openai_response(cls, response: str | None): def _preprocess_response(cls, response: str | None) -> str:
""" if not response:
This method should be implemented in the child class. It should return ""
parse the JSON response from OpenAI and return a dictionary.
"""
response = re.sub(RE_NULLS, "", response)
return response
@classmethod
def _process_response(cls, response: str) -> Self:
try: try:
return cls.model_validate_json(response or "") return cls.model_validate_json(response)
except Exception: except Exception:
logger.debug(f"Failed to parse OpenAI response as {cls}. Response: {response}") logger.debug(f"Failed to parse OpenAI response as {cls}. Response: {response}")
raise raise
@classmethod
def parse_openai_response(cls, response: str | None) -> Self:
"""
Parse the OpenAI response into a class instance.
"""
response = cls._preprocess_response(response)
return cls._process_response(response)

View file

@ -1,6 +1,7 @@
import asyncio import asyncio
import json import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import cast
import pytest import pytest
from pydantic import UUID4 from pydantic import UUID4
@ -10,6 +11,7 @@ from mealie.db.db_setup import session_context
from mealie.repos.all_repositories import get_repositories from mealie.repos.all_repositories import get_repositories
from mealie.repos.repository_factory import AllRepositories from mealie.repos.repository_factory import AllRepositories
from mealie.schema.openai.recipe_ingredient import OpenAIIngredient, OpenAIIngredients from mealie.schema.openai.recipe_ingredient import OpenAIIngredient, OpenAIIngredients
from mealie.schema.recipe.recipe import Recipe
from mealie.schema.recipe.recipe_ingredient import ( from mealie.schema.recipe.recipe_ingredient import (
CreateIngredientFood, CreateIngredientFood,
CreateIngredientFoodAlias, CreateIngredientFoodAlias,
@ -26,6 +28,7 @@ from mealie.schema.user.user import GroupBase
from mealie.services.openai import OpenAIService from mealie.services.openai import OpenAIService
from mealie.services.parser_services import RegisteredParser, get_parser from mealie.services.parser_services import RegisteredParser, get_parser
from tests.utils.factories import random_int, random_string from tests.utils.factories import random_int, random_string
from tests.utils.fixture_schemas import TestUser
@dataclass @dataclass
@ -448,3 +451,52 @@ def test_openai_parser(
assert len(parsed) == ingredient_count assert len(parsed) == ingredient_count
for input, output in zip(inputs, parsed, strict=True): for input, output in zip(inputs, parsed, strict=True):
assert output.input == input assert output.input == input
def test_openai_parser_sanitize_output(
unique_local_group_id: UUID4,
unique_user: TestUser,
parsed_ingredient_data: tuple[list[IngredientFood], list[IngredientUnit]], # required so database is populated
monkeypatch: pytest.MonkeyPatch,
):
async def mock_get_response(self, prompt: str, message: str, *args, **kwargs) -> str | None:
data = OpenAIIngredients(
ingredients=[
OpenAIIngredient(
input="there is a null character here: \x00",
confidence=1,
quantity=random_int(0, 10),
unit="",
food="there is a null character here: \x00",
note="",
)
]
)
return data.model_dump_json()
monkeypatch.setattr(OpenAIService, "get_response", mock_get_response)
with session_context() as session:
loop = asyncio.get_event_loop()
parser = get_parser(RegisteredParser.openai, unique_local_group_id, session)
parsed = loop.run_until_complete(parser.parse([""]))
assert len(parsed) == 1
parsed_ing = cast(ParsedIngredient, parsed[0])
assert parsed_ing.input
assert "\x00" not in parsed_ing.input
# Make sure we can create a recipe with this ingredient
assert isinstance(parsed_ing.ingredient.food, CreateIngredientFood)
food = unique_user.repos.ingredient_foods.create(
parsed_ing.ingredient.food.cast(SaveIngredientFood, group_id=unique_user.group_id)
)
parsed_ing.ingredient.food = food
unique_user.repos.recipes.create(
Recipe(
user_id=unique_user.user_id,
group_id=unique_user.group_id,
name=random_string(),
recipe_ingredient=[parsed_ing.ingredient],
)
)