# mcp_integration.py
"""
MCP Server Integration with Multi-Provider LLM Support
Integrates the multi-provider LLM system with your MCP server
"""
import json
import os
from typing import Dict, Any, Optional, List, AsyncGenerator
from uuid import uuid4

from llmintegrationsystem import (
    LLMIntegrationSystem,
    LLMConfig,  # used by _setup_llm_providers below; assumed to be exported from the same module
    LLMProvider,
    ChatMessage,
    LLMResponse,
)


class MCPServerWithMultiLLM:
    """
    Enhanced MCP Server with multi-provider LLM support
    Integrates with your existing MCP server and chat memory
    """

    def __init__(self, app, config: Dict[str, Any]):
        # Existing MCP server initialization
        self.app = app
        self.config = config
        self.pool = None  # PostgreSQL pool

        # Initialize LLM integration system
        self.llm = LLMIntegrationSystem()
        self._setup_llm_providers()

        # Initialize chat memory (from your existing code)
        from mcp_chat_memory import ChatMemoryManager
        self.memory_manager = ChatMemoryManager(self.pool, config)

    def _setup_llm_providers(self):
        """Setup LLM providers based on configuration"""
        # Setup Ollama as primary provider (always available locally)
        ollama_config = self.config.get('ollama', {})
        self.llm.register_provider(
            "ollama-primary",
            LLMConfig(
                provider=LLMProvider.OLLAMA,
                model=ollama_config.get('model', 'llama3.2'),
                base_url=ollama_config.get('base_url', 'http://localhost:11434'),
                temperature=ollama_config.get('temperature', 0.7),
                streaming=True,
                max_tokens=ollama_config.get('max_tokens', 4096)
            ),
            is_default=True,
            is_default_embedding=True
        )

        # Setup specialized Ollama models
        if 'ollama_models' in self.config:
            for name, model_config in self.config['ollama_models'].items():
                self.llm.register_provider(
                    name,
                    LLMConfig(
                        provider=LLMProvider.OLLAMA,
                        model=model_config['model'],
                        base_url=model_config.get('base_url', 'http://localhost:11434'),
                        temperature=model_config.get('temperature', 0.7),
                        streaming=model_config.get('streaming', True),
                        extra_params=model_config.get('extra_params', {})
                    )
                )

        # Setup OpenAI if configured
        if self.config.get('openai', {}).get('api_key'):
            openai_config = self.config['openai']
            self.llm.register_provider(
                "openai",
                LLMConfig(
                    provider=LLMProvider.OPENAI,
                    model=openai_config.get('model', 'gpt-4-turbo-preview'),
                    api_key=openai_config['api_key'],
                    temperature=openai_config.get('temperature', 0.7),
                    streaming=True,
                    max_tokens=openai_config.get('max_tokens', 4096)
                ),
                is_default_embedding=not ollama_config.get('use_for_embeddings', True)
            )

        # Setup Anthropic/Claude if configured
        if self.config.get('anthropic', {}).get('api_key'):
            anthropic_config = self.config['anthropic']
            self.llm.register_provider(
                "claude",
                LLMConfig(
                    provider=LLMProvider.ANTHROPIC,
                    model=anthropic_config.get('model', 'claude-3-opus-20240229'),
                    api_key=anthropic_config['api_key'],
                    temperature=anthropic_config.get('temperature', 0.7),
                    max_tokens=anthropic_config.get('max_tokens', 4096)
                )
            )

        # Setup Google/Gemini if configured
        if self.config.get('google', {}).get('api_key'):
            google_config = self.config['google']
            self.llm.register_provider(
                "gemini",
                LLMConfig(
                    provider=LLMProvider.GOOGLE,
                    model=google_config.get('model', 'gemini-1.5-pro'),
                    api_key=google_config['api_key'],
                    temperature=google_config.get('temperature', 0.7),
                    max_tokens=google_config.get('max_tokens', 4096)
                )
            )

        self.app.logger.info(f"LLM providers initialized: {self.llm.list_providers()}")

    async def generate_sql(self, question: str, schema_context: Dict[str, Any],
                           provider: Optional[str] = None) -> str:
        """Generate SQL using LLM with specified provider"""
        prompt = f"""You are a PostgreSQL expert. Generate a safe SELECT query for the user's question.

Database Schema:
{json.dumps(schema_context, indent=2)}

User Question: {question}

Requirements:
- Only SELECT statements (no INSERT, UPDATE, DELETE, DROP, etc.)
- Use proper JOIN syntax
- Include appropriate WHERE clauses
- Add LIMIT clauses for large result sets (max 100 rows)
- Use PostgreSQL-specific syntax
- Return ONLY the SQL query without markdown formatting or explanations

SQL Query:"""

        # Use code-specialized model if available
        if provider is None and "ollama-code" in self.llm.providers:
            provider = "ollama-code"

        response = await self.llm.complete(prompt, provider=provider)

        # Clean SQL response
        sql = response.strip()
        if sql.startswith('```'):
            sql = sql.split('```')[1]
            if sql.startswith('sql'):
                sql = sql[3:]
        sql = sql.strip().rstrip(';')
        return sql

    async def summarize_results(self, question: str, data: List[Dict],
                                provider: Optional[str] = None) -> str:
        """Summarize query results using LLM"""
        prompt = f"""Summarize these database query results to answer the user's question.

Question: {question}

Query Results:
{json.dumps(data[:20], default=str, indent=2)}
{f"... and {len(data) - 20} more rows" if len(data) > 20 else ""}

Provide a clear, concise summary in markdown format:
- Start with a direct answer to the question
- Include key statistics or insights
- Use tables for structured data when appropriate
- Keep the summary under 500 words"""

        return await self.llm.complete(prompt, provider=provider)

    async def process_query_stream(self, question: str, session_id: str,
                                   provider: Optional[str] = None) -> AsyncGenerator[str, None]:
        """Process query with streaming response"""
        # Add to memory
        if self.memory_manager:
            await self.memory_manager.add_message(session_id, 'user', question)
            context = await self.memory_manager.get_context(session_id)
        else:
            context = []

        # Get schema context (helper expected from the existing MCP server)
        schema_context = await self._get_schema_context(question)

        # Stream SQL generation
        yield json.dumps({
            'type': 'status',
            'content': 'Generating SQL query...'
        }) + '\n'

        sql_prompt = self._build_sql_prompt(question, schema_context, context)
        sql_query = ""
        async for chunk in self.llm.stream_complete(sql_prompt, provider=provider):
            sql_query += chunk
            yield json.dumps({
                'type': 'sql_generation',
                'content': chunk
            }) + '\n'

        # Clean SQL
        sql_query = self._clean_sql(sql_query)

        # Execute SQL (helper expected from the existing MCP server)
        yield json.dumps({
            'type': 'status',
            'content': 'Executing query...'
        }) + '\n'
        result = await self.execute_sql_query(sql_query)

        if result.get('success'):
            yield json.dumps({
                'type': 'query_result',
                'content': result
            }) + '\n'

            # Stream summary
            yield json.dumps({
                'type': 'status',
                'content': 'Analyzing results...'
            }) + '\n'
            summary_prompt = self._build_summary_prompt(question, result['data'], context)
            summary = ""
            async for chunk in self.llm.stream_complete(summary_prompt, provider=provider):
                summary += chunk
                yield json.dumps({
                    'type': 'summary',
                    'content': chunk
                }) + '\n'

            # Store in memory
            if self.memory_manager:
                await self.memory_manager.add_message(session_id, 'assistant', summary)
        else:
            error_msg = result.get('error', 'Query execution failed')
            yield json.dumps({
                'type': 'error',
                'content': error_msg
            }) + '\n'
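
    # Note: each chunk yielded above is one JSON object followed by a newline
    # (NDJSON). A client consuming /query/stream would therefore see a sequence
    # of events shaped roughly like this (illustrative values; the exact
    # query_result payload depends on execute_sql_query):
    #
    #   {"type": "status", "content": "Generating SQL query..."}
    #   {"type": "sql_generation", "content": "SELECT ..."}
    #   {"type": "query_result", "content": {"success": true, "data": [...]}}
    #   {"type": "summary", "content": "The query shows ..."}
    #   {"type": "error", "content": "Query execution failed"}   # only on failure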

    async def generate_embedding(self, text: str, provider: Optional[str] = None) -> List[float]:
        """Generate embedding using configured provider"""
        return await self.llm.embed(text, provider=provider)

    async def handle_mcp_request_with_llm(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle MCP request with LLM provider selection"""
        method = request_data.get('method')
        params = request_data.get('params', {})

        # Extract provider preference from params
        provider = params.pop('llm_provider', None)

        if method == 'llm/providers':
            # Return list of available providers
            return {
                'providers': [
                    self.llm.get_provider_info(name)
                    for name in self.llm.providers.keys()
                ]
            }
        elif method == 'llm/complete':
            # Direct completion
            prompt = params.get('prompt')
            response = await self.llm.complete(prompt, provider=provider)
            return {'response': response}
        elif method == 'llm/chat':
            # Chat completion
            messages = params.get('messages', [])
            response = await self.llm.chat(messages, provider=provider)
            return {'response': response}
        elif method == 'llm/embed':
            # Generate embedding
            text = params.get('text')
            embedding = await self.llm.embed(text, provider=provider)
            return {'embedding': embedding}

        # Handle other MCP methods...
        return await self.handle_mcp_request(request_data)
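
    # For reference, request bodies this dispatcher understands look like the
    # following (illustrative; 'llm_provider' is optional everywhere, and the
    # message shape for llm/chat depends on LLMIntegrationSystem.chat):
    #
    #   {"method": "llm/providers", "params": {}}
    #   {"method": "llm/complete", "params": {"prompt": "Ping?", "llm_provider": "claude"}}
    #   {"method": "llm/chat", "params": {"messages": [{"role": "user", "content": "Hi"}]}}
    #   {"method": "llm/embed", "params": {"text": "vectorize me", "llm_provider": "ollama-embed"}}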

    def _build_sql_prompt(self, question: str, schema_context: Dict,
                          conversation_context: List) -> str:
        """Build SQL generation prompt with context"""
        prompt_parts = []

        # Add conversation context if available
        if conversation_context:
            recent_context = "\n".join([
                f"{msg.role}: {msg.content[:200]}"
                for msg in conversation_context[-3:]
            ])
            prompt_parts.append(f"Recent conversation:\n{recent_context}\n")

        prompt_parts.extend([
            "You are a PostgreSQL expert. Generate a safe SELECT query.",
            f"\nDatabase Schema:\n{json.dumps(schema_context, indent=2)}",
            f"\nUser Question: {question}",
            "\nRequirements:",
            "- Only SELECT statements",
            "- Proper PostgreSQL syntax",
            "- Include LIMIT clause (max 100)",
            "- Return ONLY the SQL query",
            "\nSQL Query:"
        ])
        return "\n".join(prompt_parts)

    def _build_summary_prompt(self, question: str, data: List[Dict],
                              conversation_context: List) -> str:
        """Build summary prompt with context"""
        prompt_parts = []
        if conversation_context:
            prompt_parts.append("Consider the conversation context when summarizing.\n")
        prompt_parts.extend([
            f"Question: {question}\n",
            f"Query returned {len(data)} rows.\n",
            f"Data sample:\n{json.dumps(data[:10], default=str, indent=2)}\n",
            "Provide a concise markdown summary that directly answers the question."
        ])
        return "\n".join(prompt_parts)

    def _clean_sql(self, sql: str) -> str:
        """Clean SQL response from LLM"""
        sql = sql.strip()

        # Remove markdown code blocks
        if '```' in sql:
            parts = sql.split('```')
            for part in parts:
                if 'SELECT' in part.upper():
                    sql = part
                    break

        # Remove sql language identifier
        if sql.lower().startswith('sql'):
            sql = sql[3:].strip()

        # Remove trailing semicolon for safety
        sql = sql.rstrip(';').strip()
        return sql
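
    # Example of what _clean_sql handles (illustrative): an LLM reply such as
    #
    #   ```sql
    #   SELECT id, name FROM users LIMIT 100;
    #   ```
    #
    # is reduced to the bare statement "SELECT id, name FROM users LIMIT 100".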

    async def cleanup(self):
        """Cleanup resources"""
        await self.llm.cleanup()
        if self.pool:
            await self.pool.close()


# Configuration template
def create_mcp_config() -> Dict[str, Any]:
    """Create MCP configuration with LLM providers"""
    return {
        # Database configuration
        'db_host': os.getenv('DB_HOST', 'localhost'),
        'db_port': int(os.getenv('DB_PORT', 5432)),
        'db_name': os.getenv('DB_NAME', 'your_database'),
        'db_user': os.getenv('DB_USER', 'your_user'),
        'db_password': os.getenv('DB_PASSWORD', 'your_password'),

        # Ollama configuration (primary/default)
        'ollama': {
            'model': os.getenv('OLLAMA_MODEL', 'llama3.2'),
            'base_url': os.getenv('OLLAMA_URL', 'http://localhost:11434'),
            'temperature': float(os.getenv('OLLAMA_TEMP', 0.7)),
            'max_tokens': int(os.getenv('OLLAMA_MAX_TOKENS', 4096)),
            'use_for_embeddings': True
        },

        # Specialized Ollama models
        'ollama_models': {
            'ollama-code': {
                'model': 'codellama',
                'temperature': 0.1,
                'extra_params': {
                    'num_ctx': 8192,  # Larger context for code
                    'repeat_penalty': 1.1
                }
            },
            'ollama-sql': {
                'model': 'sqlcoder',  # If you have it
                'temperature': 0.1,
                'extra_params': {
                    'stop': [';', '\n\n']
                }
            },
            'ollama-embed': {
                'model': 'nomic-embed-text',
                'temperature': 0.0
            }
        },

        # OpenAI configuration (optional)
        'openai': {
            'api_key': os.getenv('OPENAI_API_KEY'),
            'model': os.getenv('OPENAI_MODEL', 'gpt-4-turbo-preview'),
            'temperature': float(os.getenv('OPENAI_TEMP', 0.7)),
            'max_tokens': int(os.getenv('OPENAI_MAX_TOKENS', 4096))
        },

        # Anthropic configuration (optional)
        'anthropic': {
            'api_key': os.getenv('ANTHROPIC_API_KEY'),
            'model': os.getenv('CLAUDE_MODEL', 'claude-3-opus-20240229'),
            'temperature': float(os.getenv('CLAUDE_TEMP', 0.7)),
            'max_tokens': int(os.getenv('CLAUDE_MAX_TOKENS', 4096))
        },

        # Google configuration (optional)
        'google': {
            'api_key': os.getenv('GOOGLE_API_KEY'),
            'model': os.getenv('GEMINI_MODEL', 'gemini-1.5-pro'),
            'temperature': float(os.getenv('GEMINI_TEMP', 0.7)),
            'max_tokens': int(os.getenv('GEMINI_MAX_TOKENS', 4096))
        },

        # Memory configuration
        'max_context_tokens': int(os.getenv('MAX_CONTEXT_TOKENS', 8000)),
        'short_term_window': int(os.getenv('SHORT_TERM_WINDOW', 20)),
        'summarization_interval': int(os.getenv('SUMMARIZATION_INTERVAL', 10))
    }
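
# Illustrative environment for create_mcp_config(): with no cloud API keys set,
# only the Ollama providers are registered; exporting OPENAI_API_KEY,
# ANTHROPIC_API_KEY, or GOOGLE_API_KEY additionally enables "openai", "claude",
# or "gemini" respectively.
#
#   export DB_HOST=localhost DB_NAME=your_database DB_USER=your_user
#   export OLLAMA_MODEL=llama3.2 OLLAMA_URL=http://localhost:11434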


# Flask app integration
from flask import Flask, request, jsonify, Response


def create_app():
    """Create Flask app with multi-LLM MCP server"""
    app = Flask(__name__)

    # Create configuration
    config = create_mcp_config()

    # Initialize MCP server with multi-LLM support
    mcp_server = MCPServerWithMultiLLM(app, config)

    # Note: the async view functions below require Flask 2.0+ installed with
    # the 'async' extra (pip install "flask[async]").
    @app.route('/mcp', methods=['POST'])
    async def mcp_endpoint():
        """Main MCP endpoint"""
        data = request.get_json()
        response = await mcp_server.handle_mcp_request_with_llm(data)
        return jsonify(response)

    @app.route('/query/stream', methods=['POST'])
    def stream_query():
        """Streaming query endpoint"""
        data = request.get_json()
        question = data.get('question')
        session_id = data.get('session_id', str(uuid4()))
        provider = data.get('provider')  # Optional provider selection

        def generate():
            # Bridge the async generator onto a private event loop so the
            # synchronous Flask response can stream its chunks.
            import asyncio
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            agen = mcp_server.process_query_stream(question, session_id, provider)
            try:
                while True:
                    try:
                        yield loop.run_until_complete(agen.__anext__())
                    except StopAsyncIteration:
                        break
            finally:
                loop.close()

        return Response(
            generate(),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no'
            }
        )

    @app.route('/llm/providers', methods=['GET'])
    def list_providers():
        """List available LLM providers"""
        return jsonify({
            'providers': [
                mcp_server.llm.get_provider_info(name)
                for name in mcp_server.llm.providers.keys()
            ],
            'default': mcp_server.llm.default_provider,
            'default_embedding': mcp_server.llm.default_embedding_provider
        })

    @app.route('/llm/test', methods=['POST'])
    async def test_provider():
        """Test a specific LLM provider"""
        data = request.get_json()
        provider = data.get('provider')
        prompt = data.get('prompt', 'Hello, please respond with "OK" if you are working.')

        try:
            response = await mcp_server.llm.complete(prompt, provider=provider)
            return jsonify({
                'success': True,
                'provider': provider,
                'response': response
            })
        except Exception as e:
            return jsonify({
                'success': False,
                'provider': provider,
                'error': str(e)
            }), 500

    return app
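

# Minimal entry-point sketch (an assumption, not part of the original server):
# runs the app with Flask's built-in development server. For production, a
# WSGI/ASGI server of your choice would replace app.run(); the PORT variable
# below is hypothetical.
if __name__ == '__main__':
    application = create_app()
    application.run(host='0.0.0.0', port=int(os.getenv('PORT', 8000)), debug=False)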