"""Rule-based mock LLM for deterministic intent extraction."""
from __future__ import annotations
import json
import re
from typing import Any
from pydantic import BaseModel, Field
try:
    from langchain_core.language_models.chat_models import BaseChatModel
    from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
    from langchain_core.outputs import ChatGeneration, ChatResult
except ImportError:  # pragma: no cover - local fallback for offline environments
    # Minimal stand-ins mirroring the langchain-core names used in this
    # module, so it stays importable when langchain-core is not installed.

    class BaseMessage:  # type: ignore[override]
        # Base container for a single chat message's text content.
        def __init__(self, content: str) -> None:
            self.content = content

    class HumanMessage(BaseMessage):  # type: ignore[override]
        # Marker subclass for user-authored messages.
        pass

    class AIMessage(BaseMessage):  # type: ignore[override]
        # Marker subclass for model-authored messages.
        pass

    class ChatGeneration:  # type: ignore[override]
        # Wraps one generated AIMessage, matching langchain's output shape.
        def __init__(self, message: AIMessage) -> None:
            self.message = message

    class ChatResult:  # type: ignore[override]
        # Holds the list of generations produced by one model call.
        def __init__(self, generations: list[ChatGeneration]) -> None:
            self.generations = generations

    class BaseChatModel:  # type: ignore[override]
        # Skeleton chat model: subclasses implement _generate();
        # invoke() returns the first generation's message.
        @property
        def _llm_type(self) -> str:
            return "fallback_base_model"

        def _generate(
            self,
            messages: list[BaseMessage],
            stop: list[str] | None = None,
            run_manager: Any | None = None,
            **kwargs: Any,
        ) -> ChatResult:
            raise NotImplementedError

        def invoke(self, messages: list[BaseMessage]) -> AIMessage:
            result = self._generate(messages=messages)
            return result.generations[0].message
class ParsedIntent(BaseModel):
    """Structured representation of user intent."""

    # Name of the recognized action (e.g. "add_product", "discount_by_id");
    # "unknown" when no parsing rule matched.
    action: str = Field(default="unknown")
    # Action-specific arguments, e.g. product_id / discount_percent.
    params: dict[str, Any] = Field(default_factory=dict)
def _extract_last_user_message(messages: list[BaseMessage]) -> str:
    """Return the content of the most recent HumanMessage, or "" if none."""
    human_only = (m for m in reversed(messages) if isinstance(m, HumanMessage))
    latest = next(human_only, None)
    return "" if latest is None else str(latest.content)
def _to_float(raw: str) -> float:
return float(raw.replace(",", "."))
def _parse_query(query: str) -> ParsedIntent:
    """Map a free-form (Russian) query onto a deterministic ParsedIntent.

    Rules are tried in priority order: discount by id, add product,
    average price, listing by category, full listing. Anything that
    matches no rule yields action="unknown".
    """
    text = query.strip()
    lowered = text.lower()

    discount = re.search(
        r"скидк\w*\s+(?P<discount>\d+(?:[.,]\d+)?)%.*?\bid\b\s*(?P<id>\d+)",
        lowered,
        flags=re.IGNORECASE,
    )
    if discount is not None:
        discount_params = {
            "discount_percent": _to_float(discount.group("discount")),
            "product_id": int(discount.group("id")),
        }
        return ParsedIntent(action="discount_by_id", params=discount_params)

    add_pattern = (
        r"добав[^\:]*:\s*(?P<name>[^,]+),\s*цен[аы]?\s*(?P<price>\d+(?:[.,]\d+)?),"
        r"\s*категори[яи]\s*(?P<category>[^,]+)"
    )
    added = re.search(add_pattern, text, flags=re.IGNORECASE)
    if added is not None:
        add_params = {
            "name": added.group("name").strip(),
            "price": _to_float(added.group("price")),
            "category": added.group("category").strip(),
            "in_stock": True,
        }
        return ParsedIntent(action="add_product", params=add_params)

    if "средн" in lowered and "цен" in lowered:
        return ParsedIntent(action="get_average_price")

    if "категори" in lowered:
        category = re.search(r"категори[яи]\s+([^\.,\n]+)", text, flags=re.IGNORECASE)
        if category is not None:
            return ParsedIntent(
                action="list_products_by_category",
                params={"category": category.group(1).strip()},
            )

    wants_full_listing = "все продукт" in lowered or "покажи продукт" in lowered
    if wants_full_listing:
        return ParsedIntent(action="list_products")

    return ParsedIntent(action="unknown")
class RuleBasedMockLLM(BaseChatModel):
    """Mock chat model that produces deterministic JSON plans."""

    @property
    def _llm_type(self) -> str:
        return "rule_based_mock_llm"

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: Any | None = None,
        **kwargs: Any,
    ) -> ChatResult:
        # These hooks are part of the BaseChatModel contract but unused here.
        del stop, run_manager, kwargs
        user_query = _extract_last_user_message(messages)
        plan = _parse_query(user_query)
        payload = json.dumps(plan.model_dump(), ensure_ascii=False)
        reply = AIMessage(content=payload)
        return ChatResult(generations=[ChatGeneration(message=reply)])