import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional
from mcp.server.fastmcp import FastMCP
from src.config import get_mem0_client
from starlette.responses import JSONResponse
# Module-level singletons shared by every route and tool in this file.
# The MCP server object: the routes and tools below register themselves on it
# through decorators, so it has to exist before any of them are defined.
mcp = FastMCP(name="mem0-server")
# The Mem0 client is created eagerly, at import time, so any error raised by
# get_mem0_client() surfaces when the module loads rather than on first use.
# One instance is shared by all requests.
# NOTE(review): the original author states the Mem0 client is generally
# thread-safe for API calls -- confirm for the configured backends.
memory = get_mem0_client()
# FastMCP exposes plain HTTP endpoints through @mcp.custom_route
# (there is no @mcp.get decorator).
# NOTE(review): custom routes are only reachable while an HTTP transport
# (e.g. SSE) is running; the stdio transport starts no HTTP server.
@mcp.custom_route("/health", methods=["GET"])
async def health_check(request):
    """
    Standard HTTP health check endpoint.

    Performs a tiny Mem0 search as a readiness probe.

    Returns:
        200 with {"status": "healthy", "database": "connected"} when the probe
        succeeds; 500 with {"status": "unhealthy", "error": <message>} when it
        raises.
    """
    try:
        # memory.search is synchronous. Run it in a worker thread so that a
        # slow or hung backend does not freeze the event loop, which also
        # serves every other HTTP/MCP request.
        # NOTE(review): the probe presumably exercises the embedder as well as
        # the vector/graph store, so a failure is not necessarily a DB problem.
        await asyncio.to_thread(memory.search, "ping", user_id="health_check", limit=1)
    except Exception as e:
        logging.getLogger(__name__).warning("Health check failed: %s", e)
        return JSONResponse({
            "status": "unhealthy",
            "error": str(e)
        }, status_code=500)
    return JSONResponse({
        "status": "healthy",
        "database": "connected"
    }, status_code=200)
@mcp.tool()
async def add_memory(content: str, user_id: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """
    Stores a user message into the vector database.

    Args:
        content: The text message to store.
        user_id: The Discord User ID.
        metadata: Arbitrary JSON (channel info, timestamp, etc.).

    Returns:
        The Mem0 result as pretty-printed JSON, or a string starting with
        "Error adding memory:" if the call failed.
    """
    # Send the text as a single user-role chat message rather than a raw
    # string, so Mem0 treats it as conversational input.
    messages = [{"role": "user", "content": content}]
    try:
        # memory.add is a synchronous call. Run it in a worker thread so a slow
        # backend does not block the event loop shared by all requests.
        # NOTE(review): this allows concurrent Mem0 calls; confirm the
        # configured client is thread-safe.
        result = await asyncio.to_thread(memory.add, messages, user_id=user_id, metadata=metadata)
        return json.dumps(result, indent=2)
    except Exception as e:
        # Keep the traceback in the server log; the MCP client gets the text.
        logging.getLogger(__name__).exception("add_memory failed")
        return f"Error adding memory: {str(e)}"
@mcp.tool()
async def add_fact(content: str, user_id: str, metadata: Optional[Dict[str, Any]] = None) -> str:
    """
    Stores a specific fact or extracted piece of information.

    Args:
        content: The fact text (e.g., "User is a typescript developer").
        user_id: The Discord User ID.
        metadata: Optional source info (confidence, reported_by).

    Returns:
        The Mem0 result as pretty-printed JSON, or a string starting with
        "Error adding fact:" if the call failed.
    """
    # Tag the entry so facts can be told apart from ordinary messages later.
    # Build a new dict rather than writing into the caller's mapping; any
    # caller-supplied "type" key is overridden, as before.
    fact_metadata = {**(metadata or {}), "type": "fact"}
    try:
        # The raw text is passed as-is (not wrapped as a chat message).
        # memory.add is synchronous, so run it off the event loop.
        # NOTE(review): this allows concurrent Mem0 calls; confirm the
        # configured client is thread-safe.
        result = await asyncio.to_thread(memory.add, content, user_id=user_id, metadata=fact_metadata)
        return json.dumps(result, indent=2)
    except Exception as e:
        logging.getLogger(__name__).exception("add_fact failed")
        return f"Error adding fact: {str(e)}"
@mcp.tool()
async def search_memories(query: str, user_id: str, limit: int = 5) -> str:
    """
    Semantic search for relevant memories.

    Args:
        query: The search term or current user message.
        user_id: Filter results to this specific user.
        limit: Max number of results (default: 5).

    Returns:
        The Mem0 search result as pretty-printed JSON, or a string starting
        with "Error searching memory:" if the call failed.
    """
    try:
        # memory.search is synchronous, so run it in a worker thread to keep
        # the event loop responsive.
        # NOTE(review): this allows concurrent Mem0 calls; confirm the
        # configured client is thread-safe.
        results = await asyncio.to_thread(memory.search, query, user_id=user_id, limit=limit)
        return json.dumps(results, indent=2)
    except Exception as e:
        logging.getLogger(__name__).exception("search_memories failed")
        return f"Error searching memory: {str(e)}"
@mcp.tool()
async def get_all_memories(user_id: str, limit: int = 100) -> str:
    """
    Retrieve all memories for a user.

    Args:
        user_id: The unique identifier of the user.
        limit: Max number of memories to return.

    Returns:
        The Mem0 result as pretty-printed JSON, or a string starting with
        "Error getting memories:" if the call failed.
    """
    try:
        # memory.get_all is synchronous, so run it in a worker thread to keep
        # the event loop responsive.
        # NOTE(review): this allows concurrent Mem0 calls; confirm the
        # configured client is thread-safe.
        results = await asyncio.to_thread(memory.get_all, user_id=user_id, limit=limit)
        return json.dumps(results, indent=2)
    except Exception as e:
        logging.getLogger(__name__).exception("get_all_memories failed")
        return f"Error getting memories: {str(e)}"
@mcp.tool()
async def database_history(user_id: str) -> str:
    """
    Get history of interactions/memories for a user from the graph store (if configured).
    Note: Mem0 API might expose this differently. This is a best-effort wrapper.
    """
    # The docstring above is what MCP clients see as this tool's description,
    # so it is kept unchanged.
    #
    # Implementation: a semantic alias for get_all_memories with that tool's
    # default limit.
    # NOTE(review): awaiting get_all_memories directly relies on @mcp.tool()
    # returning the undecorated coroutine function -- confirm for the SDK in use.
    try:
        history = await get_all_memories(user_id)
    except Exception as exc:
        return f"Error getting history: {exc!s}"
    return history
if __name__ == "__main__":
    # The transport is chosen with the MCP_TRANSPORT environment variable
    # (case-insensitive). The default, stdio, works with `docker exec -i`.
    # "sse" serves over HTTP, which is also the only mode in which custom
    # HTTP routes such as /health are reachable.
    transport = os.getenv("MCP_TRANSPORT", "stdio").strip().lower()
    if transport == "sse":
        # NOTE(review): the bind host/port are not set here; they come from
        # FastMCP's settings. In the official MCP Python SDK the default host
        # is 127.0.0.1, so set FASTMCP_HOST=0.0.0.0 to make the server
        # reachable from outside a container -- confirm for the installed SDK.
        mcp.run(transport="sse")
    else:
        if transport != "stdio":
            # Unknown values still fall back to stdio, but no longer silently.
            # Logging writes to stderr, so it does not corrupt the stdio
            # protocol stream on stdout.
            logging.getLogger(__name__).warning(
                "Unknown MCP_TRANSPORT %r; falling back to stdio", transport
            )
        mcp.run()