"""
AI Model Integration Module
This module provides secure interfaces to various AI models (Claude, OpenAI, local models)
with comprehensive data privacy controls to ensure sensitive data never leaves your environment.
"""
import os
import logging
import httpx
import json
from typing import Dict, List, Any, Optional, Union, AsyncGenerator
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from .data_privacy import DataPrivacyManager, DataMaskingConfig, PrivacyLevel
logger = logging.getLogger(__name__)
class ModelProvider(Enum):
    """Supported AI model providers."""
    CLAUDE = "claude"            # Anthropic Claude (remote API)
    OPENAI = "openai"            # OpenAI chat models (remote API)
    LOCAL_OLLAMA = "ollama"      # local Ollama server (data stays on-host)
    LOCAL_LLAMACPP = "llamacpp"  # local llama.cpp (currently routed through the Ollama interface)
    MOCK = "mock"  # For testing
@dataclass
class ModelConfig:
    """Configuration for AI models."""
    provider: ModelProvider                     # which backend interface to instantiate
    model_name: str                             # provider-specific model identifier
    api_key: Optional[str] = None               # interfaces fall back to provider env vars when None
    base_url: Optional[str] = None              # interfaces fall back to provider defaults when None
    max_tokens: int = 4000                      # response length cap sent with each request
    temperature: float = 0.1                    # low default favors deterministic output
    timeout: int = 30                           # HTTP client timeout, in seconds
    privacy_config: Optional[DataMaskingConfig] = None  # optional per-model privacy override
    custom_headers: Dict[str, str] = field(default_factory=dict)  # extra HTTP headers merged into requests
class AIModelInterface(ABC):
    """Abstract base class for AI model interfaces.

    Concrete subclasses implement the three generation primitives; the
    shared helpers route text through the privacy manager so every backend
    applies the same sanitization and validation rules.
    """

    def __init__(self, config: ModelConfig, privacy_manager: DataPrivacyManager):
        self.config = config
        self.privacy_manager = privacy_manager
        # Per-subclass logger, e.g. "<module>.ClaudeInterface".
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")

    @abstractmethod
    async def generate_response(self, prompt: str, context: Optional[str] = None) -> str:
        """Generate a free-form response from the AI model."""

    @abstractmethod
    async def generate_sql_query(self, natural_language: str, schema_info: Dict[str, Any]) -> str:
        """Generate a SQL query from a natural language description."""

    @abstractmethod
    async def analyze_query_results(self, query: str, results: List[Dict[str, Any]]) -> str:
        """Analyze database query results and provide insights."""

    def _sanitize_input(self, text: str) -> str:
        """Run *text* through the privacy manager's sanitizer."""
        return self.privacy_manager.sanitize_text(text)

    def _validate_prompt(self, prompt: str) -> bool:
        """Return True when *prompt* passes the privacy manager's safety check.

        Failures are logged at WARNING level before returning False.
        """
        is_safe, message = self.privacy_manager.validate_ai_prompt(prompt)
        if is_safe:
            return True
        self.logger.warning(f"Prompt validation failed: {message}")
        return False
class ClaudeInterface(AIModelInterface):
    """Interface for Claude AI models (Anthropic Messages API)."""

    def __init__(self, config: ModelConfig, privacy_manager: DataPrivacyManager):
        """Initialize from *config*, falling back to the CLAUDE_API_KEY env var.

        Raises:
            ValueError: If no API key is available from config or environment.
        """
        super().__init__(config, privacy_manager)
        self.api_key = config.api_key or os.getenv("CLAUDE_API_KEY")
        self.base_url = config.base_url or "https://api.anthropic.com"
        if not self.api_key:
            raise ValueError("Claude API key is required. Set CLAUDE_API_KEY environment variable or provide in config.")

    async def generate_response(self, prompt: str, context: Optional[str] = None) -> str:
        """Generate a response via the Claude Messages API.

        Sanitizes and validates inputs first. Returns an error string rather
        than raising on validation failure or transport errors.
        """
        # Sanitize inputs before anything leaves the local environment.
        safe_prompt = self._sanitize_input(prompt)
        safe_context = self._sanitize_input(context) if context else None
        if not self._validate_prompt(safe_prompt):
            return "Error: Prompt contains potentially sensitive information and was blocked for security."
        # FIX: merge context into a single user turn. The Messages API
        # validates role ordering and (in strict-alternation versions)
        # rejects the two consecutive "user" messages the old code produced.
        if safe_context:
            user_content = f"Context: {safe_context}\n\n{safe_prompt}"
        else:
            user_content = safe_prompt
        messages = [{"role": "user", "content": user_content}]
        headers = {
            "x-api-key": self.api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
            **self.config.custom_headers
        }
        payload = {
            "model": self.config.model_name,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature,
            "messages": messages
        }
        try:
            async with httpx.AsyncClient(timeout=self.config.timeout) as client:
                response = await client.post(
                    f"{self.base_url}/v1/messages",
                    headers=headers,
                    json=payload
                )
                response.raise_for_status()
                result = response.json()
                # FIX: guard against an empty "content" list, which the old
                # content[0] indexing turned into an IndexError.
                blocks = result.get("content") or []
                text = blocks[0].get("text", "") if blocks else ""
                self.privacy_manager.log_ai_interaction(
                    interaction_type="text_generation",
                    sanitized_input=safe_prompt[:100],
                    model_used=f"claude-{self.config.model_name}",
                    response_summary=f"Generated {len(text)} characters"
                )
                return text or "No response generated"
        except httpx.HTTPError as e:
            self.logger.error(f"Claude API error: {e}")
            return f"Error communicating with Claude: {str(e)}"
        except Exception as e:
            self.logger.error(f"Unexpected error: {e}")
            return f"Unexpected error: {str(e)}"

    async def generate_sql_query(self, natural_language: str, schema_info: Dict[str, Any]) -> str:
        """Generate a SQL query from natural language using Claude."""
        # Only a masked schema summary is shared with the remote API.
        safe_schema = self.privacy_manager.create_safe_schema_summary(schema_info)
        prompt = f"""
Based on the following database schema, generate a SQL query for this request: "{natural_language}"
Schema Information:
{json.dumps(safe_schema, indent=2)}
Requirements:
- Generate only the SQL query, no explanations
- Use proper SQL syntax
- Include appropriate LIMIT clauses for safety
- Only query existing tables and columns
SQL Query:
"""
        return await self.generate_response(prompt)

    async def analyze_query_results(self, query: str, results: List[Dict[str, Any]]) -> str:
        """Analyze query results using Claude."""
        # Results are masked before leaving the environment; only the first
        # 10 sanitized rows are sent to keep the prompt small.
        # (FIX: the old code leaked the note "# Limit to first 10 rows" as
        # literal text inside the prompt sent to the model.)
        safe_results = self.privacy_manager.sanitize_query_result(results)
        prompt = f"""
Analyze the following SQL query results and provide insights:
Query: {query}
Results: {json.dumps(safe_results[:10], indent=2)}
Total rows: {len(results)}
Please provide:
1. A brief summary of the data
2. Key insights or patterns
3. Any notable findings
Analysis:
"""
        return await self.generate_response(prompt)
class OpenAIInterface(AIModelInterface):
    """Interface for OpenAI models (Chat Completions API)."""

    def __init__(self, config: ModelConfig, privacy_manager: DataPrivacyManager):
        """Initialize from *config*, falling back to the OPENAI_API_KEY env var.

        Raises:
            ValueError: If no API key is available from config or environment.
        """
        super().__init__(config, privacy_manager)
        self.api_key = config.api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = config.base_url or "https://api.openai.com"
        if not self.api_key:
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or provide in config.")

    async def generate_response(self, prompt: str, context: Optional[str] = None) -> str:
        """Generate a response via the chat completions endpoint.

        Sanitizes and validates inputs first. Returns an error string rather
        than raising on validation failure or transport errors.
        """
        # Sanitize inputs before anything leaves the local environment.
        safe_prompt = self._sanitize_input(prompt)
        safe_context = self._sanitize_input(context) if context else None
        if not self._validate_prompt(safe_prompt):
            return "Error: Prompt contains potentially sensitive information and was blocked for security."
        # Optional context rides along as a system message.
        messages = []
        if safe_context:
            messages.append({"role": "system", "content": f"Context: {safe_context}"})
        messages.append({"role": "user", "content": safe_prompt})
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            **self.config.custom_headers
        }
        payload = {
            "model": self.config.model_name,
            "messages": messages,
            "max_tokens": self.config.max_tokens,
            "temperature": self.config.temperature
        }
        try:
            async with httpx.AsyncClient(timeout=self.config.timeout) as client:
                response = await client.post(
                    f"{self.base_url}/v1/chat/completions",
                    headers=headers,
                    json=payload
                )
                response.raise_for_status()
                result = response.json()
                # FIX: parse once and guard against an empty "choices" list,
                # which the old choices[0] indexing turned into an IndexError.
                choices = result.get("choices") or []
                content = choices[0].get("message", {}).get("content", "") if choices else ""
                self.privacy_manager.log_ai_interaction(
                    interaction_type="text_generation",
                    sanitized_input=safe_prompt[:100],
                    model_used=f"openai-{self.config.model_name}",
                    response_summary=f"Generated {len(content)} characters"
                )
                return content or "No response generated"
        except httpx.HTTPError as e:
            self.logger.error(f"OpenAI API error: {e}")
            return f"Error communicating with OpenAI: {str(e)}"
        except Exception as e:
            self.logger.error(f"Unexpected error: {e}")
            return f"Unexpected error: {str(e)}"

    async def generate_sql_query(self, natural_language: str, schema_info: Dict[str, Any]) -> str:
        """Generate a SQL query from natural language using OpenAI."""
        # Only a masked schema summary is shared with the remote API.
        safe_schema = self.privacy_manager.create_safe_schema_summary(schema_info)
        context = "You are a SQL expert. Generate only the SQL query without explanations."
        prompt = f"""
Database Schema:
{json.dumps(safe_schema, indent=2)}
Generate a SQL query for: "{natural_language}"
Requirements:
- Return only the SQL query
- Include LIMIT clauses for safety
- Use existing tables and columns only
"""
        return await self.generate_response(prompt, context)

    async def analyze_query_results(self, query: str, results: List[Dict[str, Any]]) -> str:
        """Analyze query results using OpenAI."""
        # Results are masked and truncated to 10 rows before leaving the environment.
        safe_results = self.privacy_manager.sanitize_query_result(results)
        context = "You are a data analyst. Provide concise insights about the query results."
        prompt = f"""
SQL Query: {query}
Results: {json.dumps(safe_results[:10], indent=2)}
Total rows: {len(results)}
Provide a brief analysis with key insights and patterns.
"""
        return await self.generate_response(prompt, context)
class OllamaInterface(AIModelInterface):
    """Interface for local Ollama models (HTTP API on localhost by default)."""

    def __init__(self, config: ModelConfig, privacy_manager: DataPrivacyManager):
        super().__init__(config, privacy_manager)
        # No API key needed; data never leaves the local host.
        self.base_url = config.base_url or "http://localhost:11434"

    async def generate_response(self, prompt: str, context: Optional[str] = None) -> str:
        """Generate a response using the local Ollama /api/generate endpoint.

        For local models data stays on-host, so sanitization is applied only
        when a privacy level is configured.
        """
        safe_prompt = prompt
        safe_context = context
        if self.privacy_manager.config.privacy_level != PrivacyLevel.NONE:
            # BUG FIX: the original computed safe_prompt here but then built
            # the request from the raw `prompt`, silently discarding the
            # sanitized text. Use the sanitized values (context included).
            safe_prompt = self._sanitize_input(prompt)
            if context:
                safe_context = self._sanitize_input(context)
        full_prompt = safe_prompt
        if safe_context:
            full_prompt = f"Context: {safe_context}\n\n{safe_prompt}"
        payload = {
            "model": self.config.model_name,
            "prompt": full_prompt,
            "stream": False,
            "options": {
                "temperature": self.config.temperature,
                "num_predict": self.config.max_tokens
            }
        }
        try:
            async with httpx.AsyncClient(timeout=self.config.timeout) as client:
                response = await client.post(
                    f"{self.base_url}/api/generate",
                    json=payload
                )
                response.raise_for_status()
                result = response.json()
                self.privacy_manager.log_ai_interaction(
                    interaction_type="text_generation",
                    sanitized_input=safe_prompt[:100],
                    model_used=f"ollama-{self.config.model_name}",
                    response_summary=f"Generated {len(result.get('response', ''))} characters"
                )
                return result.get("response", "No response generated")
        except httpx.HTTPError as e:
            self.logger.error(f"Ollama API error: {e}")
            return f"Error communicating with Ollama: {str(e)}"
        except Exception as e:
            self.logger.error(f"Unexpected error: {e}")
            return f"Unexpected error: {str(e)}"

    async def generate_sql_query(self, natural_language: str, schema_info: Dict[str, Any]) -> str:
        """Generate a SQL query using the local Ollama model.

        The full (unmasked) schema is used since the model runs locally.
        """
        prompt = f"""
Generate a SQL query for the following request: "{natural_language}"
Available database schema:
{json.dumps(schema_info, indent=2)}
Generate only the SQL query. Include appropriate LIMIT clauses for safety.
SQL:
"""
        return await self.generate_response(prompt)

    async def analyze_query_results(self, query: str, results: List[Dict[str, Any]]) -> str:
        """Analyze query results using the local Ollama model.

        Sends up to 20 raw rows; data stays local so no masking is applied.
        """
        analysis_data = results[:20] if len(results) > 20 else results
        prompt = f"""
Analyze the following SQL query and its results:
Query: {query}
Results (showing first 20 rows): {json.dumps(analysis_data, indent=2)}
Total rows: {len(results)}
Provide insights about:
1. Data summary
2. Notable patterns
3. Key findings
Analysis:
"""
        return await self.generate_response(prompt)
class MockInterface(AIModelInterface):
    """Mock interface for testing purposes.

    Produces canned, deterministic strings without any network activity,
    while still recording interactions through the privacy manager.
    """

    async def generate_response(self, prompt: str, context: Optional[str] = None) -> str:
        """Return a canned response and log the (truncated) prompt."""
        self.privacy_manager.log_ai_interaction(
            interaction_type="text_generation",
            sanitized_input=prompt[:100],
            model_used="mock-model",
            response_summary="Mock response generated",
        )
        return f"Mock response for prompt: {prompt[:50]}..."

    async def generate_sql_query(self, natural_language: str, schema_info: Dict[str, Any]) -> str:
        """Return a canned SQL statement that echoes the request text."""
        query = f"SELECT * FROM mock_table WHERE description LIKE '%{natural_language}%' LIMIT 10;"
        return query

    async def analyze_query_results(self, query: str, results: List[Dict[str, Any]]) -> str:
        """Return a canned analysis that reports the row count."""
        row_count = len(results)
        return f"Mock analysis: Query returned {row_count} rows with various data patterns."
class AIModelManager:
    """Manager for AI model interfaces with automatic model selection and failover."""

    def __init__(self, privacy_config: Optional[DataMaskingConfig] = None):
        # A single shared privacy manager serves every registered interface.
        self.privacy_manager = DataPrivacyManager(privacy_config)
        self.models: Dict[str, AIModelInterface] = {}
        self.default_model: Optional[str] = None

    def register_model(self, name: str, config: ModelConfig, set_as_default: bool = False):
        """Instantiate the interface for *config* and store it under *name*.

        The first successfully registered model becomes the default
        automatically; `set_as_default=True` overrides an existing default.
        """
        provider_map = {
            ModelProvider.CLAUDE: ClaudeInterface,
            ModelProvider.OPENAI: OpenAIInterface,
            ModelProvider.LOCAL_OLLAMA: OllamaInterface,
            # llama.cpp is served through the Ollama interface for now.
            ModelProvider.LOCAL_LLAMACPP: OllamaInterface,
            ModelProvider.MOCK: MockInterface,
        }
        interface_cls = provider_map.get(config.provider)
        if interface_cls is None:
            raise ValueError(f"Unsupported model provider: {config.provider}")
        try:
            self.models[name] = interface_cls(config, self.privacy_manager)
            if set_as_default or not self.default_model:
                self.default_model = name
            logger.info(f"Registered AI model: {name} ({config.provider.value})")
        except Exception as e:
            logger.error(f"Failed to register model {name}: {e}")
            raise

    def get_model(self, name: Optional[str] = None) -> AIModelInterface:
        """Return the interface for *name*, or the default model when None.

        Raises:
            ValueError: If the resolved name is unknown (or nothing is registered).
        """
        model_name = name or self.default_model
        if model_name and model_name in self.models:
            return self.models[model_name]
        raise ValueError(f"Model '{model_name}' not found. Available: {list(self.models.keys())}")

    def list_models(self) -> List[str]:
        """Return the names of all registered models."""
        return list(self.models.keys())

    async def generate_response(self, prompt: str, model_name: Optional[str] = None,
                                context: Optional[str] = None) -> str:
        """Delegate free-form generation to the chosen (or default) model."""
        return await self.get_model(model_name).generate_response(prompt, context)

    async def generate_sql_query(self, natural_language: str, schema_info: Dict[str, Any],
                                 model_name: Optional[str] = None) -> str:
        """Delegate SQL generation to the chosen (or default) model."""
        return await self.get_model(model_name).generate_sql_query(natural_language, schema_info)

    async def analyze_query_results(self, query: str, results: List[Dict[str, Any]],
                                    model_name: Optional[str] = None) -> str:
        """Delegate result analysis to the chosen (or default) model."""
        return await self.get_model(model_name).analyze_query_results(query, results)
# Pre-configured model configurations: alias -> ModelConfig.
# Every preset shares the same conservative generation settings
# (4000 max tokens, temperature 0.1); only provider and model name differ.
DEFAULT_CONFIGS = {
    alias: ModelConfig(
        provider=provider,
        model_name=model,
        max_tokens=4000,
        temperature=0.1
    )
    for alias, provider, model in (
        ("claude-3-sonnet", ModelProvider.CLAUDE, "claude-3-sonnet-20240229"),
        ("claude-3-haiku", ModelProvider.CLAUDE, "claude-3-haiku-20240307"),
        ("gpt-4", ModelProvider.OPENAI, "gpt-4"),
        ("gpt-3.5-turbo", ModelProvider.OPENAI, "gpt-3.5-turbo"),
        ("llama2", ModelProvider.LOCAL_OLLAMA, "llama2"),
        ("codellama", ModelProvider.LOCAL_OLLAMA, "codellama"),
        ("mock", ModelProvider.MOCK, "mock-model"),
    )
}