mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-07-24 07:39:41 +02:00
fix: Filter out null chars from OpenAI response (#5187)
This commit is contained in:
parent
98472ff471
commit
9a469fe4fd
2 changed files with 75 additions and 6 deletions
|
@ -1,7 +1,12 @@
|
||||||
|
import re
|
||||||
|
from typing import Self
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from mealie.core.root_logger import get_logger
|
from mealie.core.root_logger import get_logger
|
||||||
|
|
||||||
|
RE_NULLS = re.compile(r"[\x00\u0000]|\\u0000")
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,14 +19,26 @@ class OpenAIBase(BaseModel):
|
||||||
__doc__ = "" # we don't want to include the docstring in the JSON schema
|
__doc__ = "" # we don't want to include the docstring in the JSON schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_openai_response(cls, response: str | None):
|
def _preprocess_response(cls, response: str | None) -> str:
|
||||||
"""
|
if not response:
|
||||||
This method should be implemented in the child class. It should
|
return ""
|
||||||
parse the JSON response from OpenAI and return a dictionary.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
response = re.sub(RE_NULLS, "", response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _process_response(cls, response: str) -> Self:
|
||||||
try:
|
try:
|
||||||
return cls.model_validate_json(response or "")
|
return cls.model_validate_json(response)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug(f"Failed to parse OpenAI response as {cls}. Response: {response}")
|
logger.debug(f"Failed to parse OpenAI response as {cls}. Response: {response}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_openai_response(cls, response: str | None) -> Self:
|
||||||
|
"""
|
||||||
|
Parse the OpenAI response into a class instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = cls._preprocess_response(response)
|
||||||
|
return cls._process_response(response)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import UUID4
|
from pydantic import UUID4
|
||||||
|
@ -10,6 +11,7 @@ from mealie.db.db_setup import session_context
|
||||||
from mealie.repos.all_repositories import get_repositories
|
from mealie.repos.all_repositories import get_repositories
|
||||||
from mealie.repos.repository_factory import AllRepositories
|
from mealie.repos.repository_factory import AllRepositories
|
||||||
from mealie.schema.openai.recipe_ingredient import OpenAIIngredient, OpenAIIngredients
|
from mealie.schema.openai.recipe_ingredient import OpenAIIngredient, OpenAIIngredients
|
||||||
|
from mealie.schema.recipe.recipe import Recipe
|
||||||
from mealie.schema.recipe.recipe_ingredient import (
|
from mealie.schema.recipe.recipe_ingredient import (
|
||||||
CreateIngredientFood,
|
CreateIngredientFood,
|
||||||
CreateIngredientFoodAlias,
|
CreateIngredientFoodAlias,
|
||||||
|
@ -26,6 +28,7 @@ from mealie.schema.user.user import GroupBase
|
||||||
from mealie.services.openai import OpenAIService
|
from mealie.services.openai import OpenAIService
|
||||||
from mealie.services.parser_services import RegisteredParser, get_parser
|
from mealie.services.parser_services import RegisteredParser, get_parser
|
||||||
from tests.utils.factories import random_int, random_string
|
from tests.utils.factories import random_int, random_string
|
||||||
|
from tests.utils.fixture_schemas import TestUser
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -448,3 +451,52 @@ def test_openai_parser(
|
||||||
assert len(parsed) == ingredient_count
|
assert len(parsed) == ingredient_count
|
||||||
for input, output in zip(inputs, parsed, strict=True):
|
for input, output in zip(inputs, parsed, strict=True):
|
||||||
assert output.input == input
|
assert output.input == input
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_parser_sanitize_output(
|
||||||
|
unique_local_group_id: UUID4,
|
||||||
|
unique_user: TestUser,
|
||||||
|
parsed_ingredient_data: tuple[list[IngredientFood], list[IngredientUnit]], # required so database is populated
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
):
|
||||||
|
async def mock_get_response(self, prompt: str, message: str, *args, **kwargs) -> str | None:
|
||||||
|
data = OpenAIIngredients(
|
||||||
|
ingredients=[
|
||||||
|
OpenAIIngredient(
|
||||||
|
input="there is a null character here: \x00",
|
||||||
|
confidence=1,
|
||||||
|
quantity=random_int(0, 10),
|
||||||
|
unit="",
|
||||||
|
food="there is a null character here: \x00",
|
||||||
|
note="",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return data.model_dump_json()
|
||||||
|
|
||||||
|
monkeypatch.setattr(OpenAIService, "get_response", mock_get_response)
|
||||||
|
|
||||||
|
with session_context() as session:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
parser = get_parser(RegisteredParser.openai, unique_local_group_id, session)
|
||||||
|
|
||||||
|
parsed = loop.run_until_complete(parser.parse([""]))
|
||||||
|
assert len(parsed) == 1
|
||||||
|
parsed_ing = cast(ParsedIngredient, parsed[0])
|
||||||
|
assert parsed_ing.input
|
||||||
|
assert "\x00" not in parsed_ing.input
|
||||||
|
|
||||||
|
# Make sure we can create a recipe with this ingredient
|
||||||
|
assert isinstance(parsed_ing.ingredient.food, CreateIngredientFood)
|
||||||
|
food = unique_user.repos.ingredient_foods.create(
|
||||||
|
parsed_ing.ingredient.food.cast(SaveIngredientFood, group_id=unique_user.group_id)
|
||||||
|
)
|
||||||
|
parsed_ing.ingredient.food = food
|
||||||
|
unique_user.repos.recipes.create(
|
||||||
|
Recipe(
|
||||||
|
user_id=unique_user.user_id,
|
||||||
|
group_id=unique_user.group_id,
|
||||||
|
name=random_string(),
|
||||||
|
recipe_ingredient=[parsed_ing.ingredient],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue