# groq_llm.py
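"""Groq chat-completions client for the MCP Agent protocol.

Wraps Groq's OpenAI-compatible chat completions endpoint behind the
project's BaseLLM interface.
"""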
import logging
from typing import Dict, List, Optional

import httpx

from src.llm.base_llm import BaseLLM


class GroqLLM(BaseLLM):
    """
    Implementation of Groq API for the MCP Agent protocol.
    This class provides Groq-specific implementations of the LLM interface,
    including model selection and API communication.
    """

    _GROQ_API_ENDPOINT = "https://api.groq.com/openai/v1/chat/completions"

    # Models available from Groq
    _GROQ_MODELS = [
        "llama-3.2-90b-vision-preview",
        "llama-3.2-8b-chat-preview",
        "llama-3.2-70b-chat",
        "mixtral-8x7b-32768",
        "gemma-7b-it",
    ]

    def __init__(self, api_key: str, model: Optional[str] = None) -> None:
        """
        Initialize the Groq LLM client.

        Args:
            api_key: Groq API key for authentication.
            model: Optional model identifier. If omitted, or not one of the
                supported models, default_model is used instead.
        """
        super().__init__(api_key)
        self._model = model if model in self._GROQ_MODELS else self.default_model

    @property
    def default_model(self) -> str:
        """
        Get the default model for Groq.

        Returns:
            The default model identifier as a string.
        """
        return "llama-3.2-90b-vision-preview"

    @property
    def available_models(self) -> List[str]:
        """
        Get a list of available Groq models.

        Returns:
            List of model identifiers as strings.
        """
        return self._GROQ_MODELS.copy()

    @property
    def model(self) -> str:
        """
        Get the currently selected model.

        Returns:
            The current model identifier as a string.
        """
        return self._model

    @model.setter
    def model(self, model_name: str) -> None:
        """
        Set the model to use for requests.

        Args:
            model_name: The model identifier to use.

        Raises:
            ValueError: If the model is not supported by Groq.
        """
        if model_name not in self._GROQ_MODELS:
            raise ValueError(
                f"Model {model_name} is not supported by Groq. "
                f"Available models: {', '.join(self._GROQ_MODELS)}"
            )
        self._model = model_name

    def get_response(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        top_p: float = 1.0,
    ) -> str:
        """
        Get a response from the Groq LLM.

        Args:
            messages: A list of message dictionaries in the OpenAI chat format,
                each with "role" and "content" keys.
            temperature: Controls randomness. Higher values (e.g., 0.8) make
                output more random; lower values (e.g., 0.2) make it more
                deterministic.
            max_tokens: Maximum number of tokens to generate.
            top_p: Controls diversity via nucleus sampling. 1.0 means no
                filtering.

        Returns:
            The LLM's response as a string, or a user-facing error message if
            the request fails.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
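
        # Groq's endpoint is OpenAI-compatible, so the payload follows the
        # standard chat-completions request schema.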
        payload = {
            "messages": messages,
            "model": self._model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stream": False,
            "stop": None,
        }
        try:
            # An explicit timeout keeps long generations from tripping
            # httpx's 5-second default.
            with httpx.Client(timeout=60.0) as client:
                response = client.post(
                    self._GROQ_API_ENDPOINT, headers=headers, json=payload
                )
                response.raise_for_status()
                data = response.json()
                return data["choices"][0]["message"]["content"]
        except httpx.HTTPError as e:
            # httpx.HTTPError is the common base of RequestError (transport
            # failures) and HTTPStatusError (raised by raise_for_status());
            # catching only RequestError would let HTTPStatusError propagate.
            error_message = f"Error getting Groq LLM response: {str(e)}"
            logging.error(error_message)
            if isinstance(e, httpx.HTTPStatusError):
                logging.error(f"Status code: {e.response.status_code}")
                logging.error(f"Response details: {e.response.text}")
            return (
                f"I encountered an error: {error_message}. "
                "Please try again or rephrase your request."
            )
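

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. Assumes a
    # GROQ_API_KEY environment variable and that BaseLLM needs only the key.
    import os

    llm = GroqLLM(api_key=os.environ["GROQ_API_KEY"])
    llm.model = "mixtral-8x7b-32768"  # any entry from available_models works
    reply = llm.get_response(
        [{"role": "user", "content": "Say hello in one short sentence."}],
        temperature=0.2,
    )
    print(reply)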