Skip to main content
Glama
by mckinsey
_pydantic_output.py (4.98 kB)
"""Contains the _get_pydantic_model for the Vizro AI dashboard.""" # ruff: noqa: F821 import logging from inspect import signature from typing import Any import plotly.express as px import vizro.models as vm from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import HumanMessage from langchain_core.prompts import ChatPromptTemplate from pydantic import BaseModel, ValidationError logger = logging.getLogger(__name__) BASE_PROMPT = """ You are a front-end developer with expertise in Plotly, Dash, and the visualization library named Vizro. Your goal is to summarize the given specifications into the given Pydantic schema. IMPORTANT: Please always output your response by using a tool. This is the task context: {df_info} Additional information: {additional_info} Here is the user request: """ def _create_prompt_template(additional_info: str) -> ChatPromptTemplate: """Create the ChatPromptTemplate from the base prompt and additional info.""" return ChatPromptTemplate.from_messages( [ ("system", BASE_PROMPT.format(df_info="{df_info}", additional_info=additional_info)), ("placeholder", "{message}"), ] ) SINGLE_MODEL_PROMPT = _create_prompt_template("") MODEL_REPROMPT = _create_prompt_template("Pay special attention to the following error: {validation_error}") def _create_prompt(retry: bool = False) -> ChatPromptTemplate: """Create the prompt message for the LLM model.""" return MODEL_REPROMPT if retry else SINGLE_MODEL_PROMPT def _create_message_content(query: str, df_info: Any, validation_error: str | None = None, retry: bool = False) -> dict: """Create the message content for the LLM model.""" message_content = {"message": [HumanMessage(content=query)], "df_info": df_info} if retry: message_content["validation_error"] = validation_error return message_content def _get_pydantic_model( query: str, llm_model: BaseChatModel, response_model: BaseModel, df_info: Any | None = None, # TODO: this should potentially not be part of this 
function. max_retry: int = 2, ) -> BaseModel: # TODO: fix typing similar to instructor library, ie the return type should be the same as response_model # At the very least it should include the string type of the validation error """Get the pydantic output from the LLM model with retry logic.""" for attempt in range(max_retry): attempt_is_retry = attempt > 0 prompt = _create_prompt(retry=attempt_is_retry) message_content = _create_message_content( query, df_info, str(last_validation_error) if attempt_is_retry else None, retry=attempt_is_retry ) try: kwargs = {} # Only pass `method` parameter if the model's with_structured_output accepts it # This is determined by checking the signature of the method # By the time this code written, the `method` parameter is supported by # model providers like OpenAI, MistralAI, VertexAI, etc. try: sig = signature(llm_model.with_structured_output) if "method" in sig.parameters: kwargs["method"] = "function_calling" # method 'json_schema' does not work with `pattern` in Field except (ValueError, AttributeError): pass pydantic_llm = prompt | llm_model.with_structured_output(response_model, **kwargs) return pydantic_llm.invoke(message_content) except ValidationError as validation_error: last_validation_error = validation_error # TODO: should this be shifted to logging so that that one can control what output gets shown (e.g. in public demos) raise last_validation_error if __name__ == "__main__": import plotly.express as px import vizro.models as vm from vizro_ai._llm_models import _get_llm_model from vizro_ai.plot._response_models import ChartPlanStatic llm = _get_llm_model() import os from langchain_mistralai import ChatMistralAI llm = ChatMistralAI( name="codestral-latest", temperature=0, max_retries=2, endpoint=os.environ.get("MISTRAL_BASE_URL"), mistral_api_key=os.environ.get("ANTHROPIC_API_KEY"), # other params... 
) # Easy component_description = "Create a card with the following content: 'Hello, world!'" res = _get_pydantic_model(query=component_description, llm_model=llm, response_model=vm.Card) print(res.__repr__()) # noqa: T201 print(type(res)) # noqa: T201 # Harder df = px.data.gapminder().sample(5).to_markdown() component_description2 = "the trend of gdp over years in the US" res2 = _get_pydantic_model(query=component_description2, df_info=df, llm_model=llm, response_model=ChartPlanStatic) print(res2.__repr__()) # noqa: T201 print(type(res2)) # noqa: T201

Latest Blog Posts

MCP directory API

We provide all the information about MCP servers via our MCP API.

curl -X GET 'https://glama.ai/api/mcp/v1/servers/mckinsey/vizro'

If you have feedback or need assistance with the MCP directory API, please join our Discord server.