mirror of
https://github.com/mealie-recipes/mealie.git
synced 2025-07-24 15:49:42 +02:00
fix: Filter out null chars from OpenAI response (#5187)
This commit is contained in:
parent
98472ff471
commit
9a469fe4fd
2 changed files with 75 additions and 6 deletions
|
@ -1,7 +1,12 @@
|
|||
import re
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from mealie.core.root_logger import get_logger
|
||||
|
||||
# Matches a null in either form it can take in response text: the actual NUL
# character, or the literal six-character escape sequence (backslash + "u0000").
RE_NULLS = re.compile(r"\x00|\\u0000")
|
||||
|
||||
# Module-level logger, obtained from mealie.core.root_logger.get_logger().
logger = get_logger()
|
||||
|
||||
|
||||
|
@ -14,14 +19,26 @@ class OpenAIBase(BaseModel):
|
|||
__doc__ = "" # we don't want to include the docstring in the JSON schema
|
||||
|
||||
@classmethod
def _preprocess_response(cls, response: str | None) -> str:
    """
    Clean up a raw OpenAI response before it is validated.

    Every match of ``RE_NULLS`` (null characters, raw or escaped) is removed
    from the text.

    :param response: the raw response text, or None when there is none
    :returns: the response with nulls removed; an empty string when the
        response is None or empty
    """
    if not response:
        return ""

    # RE_NULLS is already compiled, so call its own substitution method
    return RE_NULLS.sub("", response)
|
||||
|
||||
@classmethod
def _process_response(cls, response: str) -> Self:
    """
    Validate a JSON response string into an instance of this class.

    :param response: the JSON text to validate
    :returns: the instance produced by ``model_validate_json``
    :raises Exception: any error raised during validation is re-raised
        unchanged, after the offending response is logged at debug level
    """
    try:
        # A single return: for a str argument this is identical to
        # validating `response or ""`, so no second statement is needed.
        return cls.model_validate_json(response)
    except Exception:
        logger.debug(f"Failed to parse OpenAI response as {cls}. Response: {response}")
        raise
|
||||
|
||||
@classmethod
def parse_openai_response(cls, response: str | None) -> Self:
    """
    Build an instance of this class from a raw OpenAI response.

    The response is first passed through ``_preprocess_response``; the result
    is then handed to ``_process_response``, whose return value is returned.
    """
    cleaned = cls._preprocess_response(response)
    return cls._process_response(cleaned)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from pydantic import UUID4
|
||||
|
@ -10,6 +11,7 @@ from mealie.db.db_setup import session_context
|
|||
from mealie.repos.all_repositories import get_repositories
|
||||
from mealie.repos.repository_factory import AllRepositories
|
||||
from mealie.schema.openai.recipe_ingredient import OpenAIIngredient, OpenAIIngredients
|
||||
from mealie.schema.recipe.recipe import Recipe
|
||||
from mealie.schema.recipe.recipe_ingredient import (
|
||||
CreateIngredientFood,
|
||||
CreateIngredientFoodAlias,
|
||||
|
@ -26,6 +28,7 @@ from mealie.schema.user.user import GroupBase
|
|||
from mealie.services.openai import OpenAIService
|
||||
from mealie.services.parser_services import RegisteredParser, get_parser
|
||||
from tests.utils.factories import random_int, random_string
|
||||
from tests.utils.fixture_schemas import TestUser
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -448,3 +451,52 @@ def test_openai_parser(
|
|||
assert len(parsed) == ingredient_count
|
||||
for input, output in zip(inputs, parsed, strict=True):
|
||||
assert output.input == input
|
||||
|
||||
|
||||
def test_openai_parser_sanitize_output(
    unique_local_group_id: UUID4,
    unique_user: TestUser,
    parsed_ingredient_data: tuple[list[IngredientFood], list[IngredientUnit]],  # required so database is populated
    monkeypatch: pytest.MonkeyPatch,
):
    """A null character in the OpenAI response must not survive into the parsed ingredient's input."""

    async def mock_get_response(self, prompt: str, message: str, *args, **kwargs) -> str | None:
        # Stand-in for OpenAIService.get_response: a single ingredient whose
        # input and food both contain a null character, serialized to JSON.
        tainted = OpenAIIngredient(
            input="there is a null character here: \x00",
            confidence=1,
            quantity=random_int(0, 10),
            unit="",
            food="there is a null character here: \x00",
            note="",
        )
        return OpenAIIngredients(ingredients=[tainted]).model_dump_json()

    monkeypatch.setattr(OpenAIService, "get_response", mock_get_response)

    with session_context() as session:
        loop = asyncio.get_event_loop()
        parser = get_parser(RegisteredParser.openai, unique_local_group_id, session)

        results = loop.run_until_complete(parser.parse([""]))
        assert len(results) == 1
        parsed_ing = cast(ParsedIngredient, results[0])
        assert parsed_ing.input
        assert "\x00" not in parsed_ing.input

        # The parsed ingredient must also be storable: save its food first,
        # then a recipe that uses the ingredient.
        assert isinstance(parsed_ing.ingredient.food, CreateIngredientFood)
        food_to_save = parsed_ing.ingredient.food.cast(SaveIngredientFood, group_id=unique_user.group_id)
        parsed_ing.ingredient.food = unique_user.repos.ingredient_foods.create(food_to_save)

        recipe = Recipe(
            user_id=unique_user.user_id,
            group_id=unique_user.group_id,
            name=random_string(),
            recipe_ingredient=[parsed_ing.ingredient],
        )
        unique_user.repos.recipes.create(recipe)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue