"""
Gemma3 client wrapper for Ollama integration.
This module provides a simple interface to interact with Gemma3 models served
by Ollama (default model: gemma3n:e4b) for processing business documents and
extracting insights.
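
Quick-start sketch (assumes a local Ollama server with the default model
already pulled; the prompt shown is illustrative):

    client = get_gemma3_client()
    result = client.process_with_json("Extract the invoice total from: ...")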
"""
import json
import logging
import os
from typing import Dict, Any, Optional
import asyncio
from functools import wraps
from dotenv import load_dotenv
from ollama import Client
from prompts import get_analysis_config
# Load environment variables from .env file
load_dotenv()
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Gemma3ClientError(Exception):
"""Base exception for Gemma3 client errors."""
pass
class Gemma3ConnectionError(Gemma3ClientError):
"""Exception raised when connection to Ollama fails."""
pass
class Gemma3ProcessingError(Gemma3ClientError):
"""Exception raised when Gemma3 processing fails."""
pass
def retry_on_failure(max_retries: int = 3, delay: float = 1.0):
"""Decorator to retry operations on failure."""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
if attempt == max_retries - 1:
raise
logger.warning(f"Attempt {attempt + 1} failed: {e}. Retrying in {delay}s...")
await asyncio.sleep(delay)
return wrapper
return decorator
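# Hypothetical usage of the retry decorator (the coroutine below is illustrative):
#
#     @retry_on_failure(max_retries=2, delay=0.5)
#     async def flaky_fetch() -> str:
#         ...  # any awaitable operation that may fail transiently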
class Gemma3Client:
"""
    Client for interacting with a Gemma3 model served by Ollama (default: gemma3n:e4b).
Provides methods for processing different types of content with appropriate
configurations and error handling.
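
    Example (a minimal sketch; assumes a reachable Ollama server):

        client = Gemma3Client(host="http://localhost:11434", model="gemma3n:e4b")
        text = client.generate("Summarize this memo: ...", "document_processing")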
"""
    # Default generation options based on project specifications.
    # The model itself is selected via self.model; only sampling options belong
    # here, because this dict is passed to Ollama as the "options" payload.
    DEFAULT_CONFIG = {
        "temperature": 0.1,    # Low for consistent extraction
        "num_ctx": 4096,       # Context window (conservative for small models)
        "top_p": 0.9,
        "repeat_penalty": 1.1,
        "num_predict": 800     # Maximum tokens to generate
    }
def __init__(self,
host: str = "http://localhost:11434",
model: str = "gemma3n:e4b",
timeout: int = 30):
"""
Initialize Gemma3 client.
Args:
host: Ollama server host URL
model: Model name to use
timeout: Request timeout in seconds
"""
self.host = host
self.model = model
self.timeout = timeout
        # Forward the timeout so the underlying HTTP client actually enforces it
        self.client = Client(host=host, timeout=timeout)
# Test connection on initialization
self._test_connection()
def _test_connection(self) -> None:
"""Test connection to Ollama server."""
try:
# Try to list models to test connection
            models = self.client.list()
            available_models = [m.model for m in models.models]
            logger.info(f"Connected to Ollama at {self.host}")
            logger.info(f"Available models: {available_models}")
            # Check if our model is available
            if self.model not in available_models:
logger.warning(f"Model {self.model} not found. Available: {available_models}")
# Try to pull the model
logger.info(f"Attempting to pull model {self.model}...")
self.client.pull(self.model)
logger.info(f"Successfully pulled model {self.model}")
        except Exception as e:
            raise Gemma3ConnectionError(f"Failed to connect to Ollama: {e}") from e
def _prepare_config(self, analysis_type: str = "document_processing") -> Dict[str, Any]:
"""
Prepare configuration for specific analysis type.
Args:
analysis_type: Type of analysis being performed
Returns:
Configuration dictionary for Ollama
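
        Example (a sketch; actual values come from prompts.get_analysis_config):
            If get_analysis_config("summary") returned {"temperature": 0.3,
            "max_tokens": 400}, the result would be DEFAULT_CONFIG with
            temperature=0.3 and num_predict=400.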
"""
config = self.DEFAULT_CONFIG.copy()
analysis_config = get_analysis_config(analysis_type)
# Update with analysis-specific settings
if "temperature" in analysis_config:
config["temperature"] = analysis_config["temperature"]
if "max_tokens" in analysis_config:
config["num_predict"] = analysis_config["max_tokens"]
return config
@retry_on_failure(max_retries=3, delay=2.0)
async def generate_async(self,
prompt: str,
analysis_type: str = "document_processing") -> str:
"""
Generate response asynchronously using Gemma3.
Args:
prompt: Input prompt for the model
analysis_type: Type of analysis for configuration
Returns:
Generated response text
Raises:
Gemma3ProcessingError: If generation fails
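
        Example (a sketch; requires a running event loop and a reachable server):
            text = await client.generate_async("Classify this email: ...")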
"""
try:
config = self._prepare_config(analysis_type)
response = await asyncio.to_thread(
self.client.generate,
model=self.model,
prompt=prompt,
options=config
)
if not response or 'response' not in response:
raise Gemma3ProcessingError("Empty response from Gemma3")
return response['response'].strip()
        except Gemma3ProcessingError:
            raise
        except Exception as e:
            logger.error(f"Gemma3 generation failed: {e}")
            raise Gemma3ProcessingError(f"Generation failed: {e}") from e
def generate(self,
prompt: str,
analysis_type: str = "document_processing") -> str:
"""
Generate response synchronously using Gemma3.
Args:
prompt: Input prompt for the model
analysis_type: Type of analysis for configuration
Returns:
Generated response text
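
        Example (a sketch; assumes a reachable Ollama server):
            text = client.generate("List the action items in: ...")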
"""
try:
config = self._prepare_config(analysis_type)
response = self.client.generate(
model=self.model,
prompt=prompt,
options=config
)
if not response or 'response' not in response:
raise Gemma3ProcessingError("Empty response from Gemma3")
return response['response'].strip()
        except Gemma3ProcessingError:
            raise
        except Exception as e:
            logger.error(f"Gemma3 generation failed: {e}")
            raise Gemma3ProcessingError(f"Generation failed: {e}") from e
def parse_json_response(self, response: str) -> Dict[str, Any]:
"""
Parse JSON response from Gemma3, with error handling.
Args:
response: Raw response text from Gemma3
Returns:
Parsed JSON dictionary
        Note:
            On a parse failure this method does not raise; it returns a fallback
            dictionary containing the raw response and the parse error.
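
        Example:
            parsed = client.parse_json_response('{"status": "ok"}')
            # parsed == {"status": "ok"}; fenced ```json blocks are stripped first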
"""
try:
            # Strip Markdown code fences if the model wrapped its JSON in them
            cleaned_response = response.strip()
            if cleaned_response.startswith("```json"):
                cleaned_response = cleaned_response[7:]
            elif cleaned_response.startswith("```"):
                cleaned_response = cleaned_response[3:]
            if cleaned_response.endswith("```"):
                cleaned_response = cleaned_response[:-3]
            cleaned_response = cleaned_response.strip()
# Parse JSON
parsed = json.loads(cleaned_response)
return parsed
except json.JSONDecodeError as e:
logger.error(f"Failed to parse JSON response: {e}")
logger.error(f"Raw response: {response}")
# Return a fallback structure
return {
"error": "Failed to parse response",
"raw_response": response,
"parse_error": str(e)
}
async def process_with_json_async(self,
prompt: str,
analysis_type: str = "document_processing") -> Dict[str, Any]:
"""
Process prompt and return parsed JSON response asynchronously.
Args:
prompt: Input prompt for the model
analysis_type: Type of analysis for configuration
Returns:
Parsed JSON response
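
        Example (a sketch; run inside an event loop):
            data = await client.process_with_json_async(prompt)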
"""
response = await self.generate_async(prompt, analysis_type)
return self.parse_json_response(response)
def process_with_json(self,
prompt: str,
analysis_type: str = "document_processing") -> Dict[str, Any]:
"""
Process prompt and return parsed JSON response synchronously.
Args:
prompt: Input prompt for the model
analysis_type: Type of analysis for configuration
Returns:
Parsed JSON response
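
        Example (a sketch):
            data = client.process_with_json("Return JSON with keys title, date: ...")
            if "error" not in data:
                ...  # use the parsed fields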
"""
response = self.generate(prompt, analysis_type)
return self.parse_json_response(response)
def health_check(self) -> Dict[str, Any]:
"""
Perform health check on the Gemma3 client.
Returns:
Health status information
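
        Example:
            status = client.health_check()
            # status["status"] is "healthy" or "unhealthy"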
"""
try:
# Test with a simple prompt
test_prompt = "Return JSON: {\"status\": \"healthy\", \"model\": \"gemma3\"}"
response = self.generate(test_prompt, "document_processing")
return {
"status": "healthy",
"host": self.host,
"model": self.model,
"test_response": response[:100] + "..." if len(response) > 100 else response
}
except Exception as e:
return {
"status": "unhealthy",
"host": self.host,
"model": self.model,
"error": str(e)
}
# Global client instance (singleton pattern)
_global_client: Optional[Gemma3Client] = None
def get_gemma3_client(host: Optional[str] = None,
model: Optional[str] = None) -> Gemma3Client:
"""
    Get or create the global Gemma3 client instance.
    Note: host and model only take effect when the client is first created;
    later calls return the existing instance (use reset_global_client to rebuild).
Args:
host: Ollama server host URL (defaults to env OLLAMA_HOST or localhost:11434)
model: Model name to use (defaults to env OLLAMA_MODEL or gemma3n:e4b)
Returns:
Gemma3Client instance
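
    Example (a sketch; environment variables are read only on first creation):
        client = get_gemma3_client()
        print(client.health_check()["status"])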
"""
global _global_client
if _global_client is None:
# Use environment variables if not provided
host = host or os.getenv("OLLAMA_HOST", "http://localhost:11434")
model = model or os.getenv("OLLAMA_MODEL", "gemma3n:e4b")
_global_client = Gemma3Client(host=host, model=model)
return _global_client
def reset_global_client():
"""Reset the global client instance."""
global _global_client
_global_client = None
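

if __name__ == "__main__":
    # Minimal smoke test (a sketch): assumes a local Ollama server; the client
    # will try to pull the default model if it is not already available.
    demo_client = get_gemma3_client()
    print(json.dumps(demo_client.health_check(), indent=2))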