memory_mcp_server.py•4.62 kB
from mcp.server.fastmcp import FastMCP, Context
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from dotenv import load_dotenv
from mem0 import Memory
import asyncio
import json
import os
from utils import get_mem0_client
load_dotenv()
# Default user ID for memory operations
DEFAULT_USER_ID = "user"
# Create a dataclass for our application context
@dataclass
class Mem0Context:
"""Context for the Mem0 MCP server."""
mem0_client: Memory
@asynccontextmanager
async def mem0_lifespan(server: FastMCP) -> AsyncIterator[Mem0Context]:
"""
Manages the Mem0 client lifecycle.
Args:
server: The FastMCP server instance
Yields:
Mem0Context: The context containing the Mem0 client
"""
# Create and return the Memory client with the helper function in utils.py
mem0_client = get_mem0_client()
try:
yield Mem0Context(mem0_client=mem0_client)
finally:
# No explicit cleanup needed for the Mem0 client
pass
# Initialize FastMCP server with the Mem0 client as context
mcp = FastMCP(
"mcp-mem0",
description="MCP server for long term memory storage and retrieval with Mem0",
lifespan=mem0_lifespan,
host=os.getenv("HOST", "0.0.0.0"),
port=os.getenv("PORT", "8050")
)
@mcp.tool()
async def save_memory(ctx: Context, text: str) -> str:
"""Save information to your long-term memory.
This tool is designed to store any type of information that might be useful in the future.
The content will be processed and indexed for later retrieval through semantic search.
Args:
ctx: The MCP server provided context which includes the Mem0 client
text: The content to store in memory, including any relevant details and context
"""
try:
mem0_client = ctx.request_context.lifespan_context.mem0_client
messages = [{"role": "user", "content": text}]
mem0_client.add(messages, user_id=DEFAULT_USER_ID)
return f"Successfully saved memory: {text[:100]}..." if len(text) > 100 else f"Successfully saved memory: {text}"
except Exception as e:
return f"Error saving memory: {str(e)}"
@mcp.tool()
async def get_all_memories(ctx: Context) -> str:
"""Get all stored memories for the user.
Call this tool when you need complete context of all previously memories.
Args:
ctx: The MCP server provided context which includes the Mem0 client
Returns a JSON formatted list of all stored memories, including when they were created
and their content. Results are paginated with a default of 50 items per page.
"""
try:
mem0_client = ctx.request_context.lifespan_context.mem0_client
memories = mem0_client.get_all(user_id=DEFAULT_USER_ID)
if isinstance(memories, dict) and "results" in memories:
flattened_memories = [memory["memory"] for memory in memories["results"]]
else:
flattened_memories = memories
return json.dumps(flattened_memories, indent=2)
except Exception as e:
return f"Error retrieving memories: {str(e)}"
@mcp.tool()
async def search_memories(ctx: Context, query: str, limit: int = 3) -> str:
"""Search memories using semantic search.
This tool should be called to find relevant information from your memory. Results are ranked by relevance.
Always search your memories before making decisions to ensure you leverage your existing knowledge.
Args:
ctx: The MCP server provided context which includes the Mem0 client
query: Search query string describing what you're looking for. Can be natural language.
limit: Maximum number of results to return (default: 3)
"""
try:
mem0_client = ctx.request_context.lifespan_context.mem0_client
memories = mem0_client.search(query, user_id=DEFAULT_USER_ID, limit=limit)
if isinstance(memories, dict) and "results" in memories:
flattened_memories = [memory["memory"] for memory in memories["results"]]
else:
flattened_memories = memories
return json.dumps(flattened_memories, indent=2)
except Exception as e:
return f"Error searching memories: {str(e)}"
async def main():
transport = os.getenv("TRANSPORT", "sse")
if transport == 'sse':
# Run the MCP server with sse transport
await mcp.run_sse_async()
else:
# Run the MCP server with stdio transport
await mcp.run_stdio_async()
if __name__ == "__main__":
asyncio.run(main())