# MCP Reddit Server
# by wllcnm
# - src
import asyncio
import logging
import os
from typing import List
import praw
from mcp.server import Server
from mcp.types import Tool, TextContent
from mcp.server.stdio import stdio_server
# Logging configuration: INFO level; every record carries a timestamp,
# the logger name and the level name ahead of the message.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-wide logger used by the server class below.
logger = logging.getLogger("reddit_mcp_server")
class RedditMCPServer:
    """MCP server that exposes read-only Reddit lookups as tools.

    Three tools are registered on the underlying ``Server``:
    ``search_subreddit``, ``get_post_details`` and ``get_subreddit_hot``.
    Reddit credentials are read from the environment on every tool call.
    """

    def __init__(self):
        """Create the MCP ``Server`` and register the tool handlers on it."""
        self.app = Server("reddit_mcp_server")
        self.setup_tools()

    def get_reddit_client(self):
        """Build a PRAW client from environment variables.

        Reads ``REDDIT_CLIENT_ID`` and ``REDDIT_CLIENT_SECRET`` (both
        required) and ``REDDIT_USER_AGENT`` (defaults to
        ``"MCP-Reddit/1.0"``).

        Returns:
            The ``praw.Reddit`` instance built from those values.

        Raises:
            ValueError: If the client id or the client secret is missing
                or empty.
        """
        client_id = os.environ.get("REDDIT_CLIENT_ID")
        client_secret = os.environ.get("REDDIT_CLIENT_SECRET")
        user_agent = os.environ.get("REDDIT_USER_AGENT", "MCP-Reddit/1.0")
        if not (client_id and client_secret):
            raise ValueError("Missing Reddit API credentials in environment variables")
        return praw.Reddit(
            client_id=client_id,
            client_secret=client_secret,
            user_agent=user_agent,
        )

    @staticmethod
    def _format_post_summary(post) -> str:
        """Render the short summary used by the listing tools for one post."""
        return (
            f"Title: {post.title}\n"
            f"ID: {post.id}\n"
            f"Score: {post.score}\n"
            f"URL: {post.url}\n"
            f"Created: {post.created_utc}\n"
            f"---"
        )

    @classmethod
    def _format_listing(cls, posts, empty_message: str) -> str:
        """Join the summaries of *posts*, or return *empty_message* if none.

        *posts* may be a lazy iterable; it is consumed exactly once here.
        """
        summaries = [cls._format_post_summary(post) for post in posts]
        if not summaries:
            # An empty text payload is indistinguishable from a failure for
            # the caller, so say explicitly that nothing matched.
            return empty_message
        return "\n\n".join(summaries)

    @staticmethod
    def _format_post_details(post, comment_limit) -> str:
        """Render a post followed by at most *comment_limit* comments.

        A negative *comment_limit* is treated as 0: used directly as a slice
        bound it would mean "all but the last N comments", which is never
        what a "maximum number of comments" argument intends.
        """
        if comment_limit is not None and comment_limit < 0:
            comment_limit = 0
        lines = [
            f"Title: {post.title}",
            f"Author: {post.author}",
            f"Score: {post.score}",
            f"Content: {post.selftext if post.selftext else '[No text content]'}",
            f"URL: {post.url}",
            "\nTop Comments:",
        ]
        lines.extend(
            f"\nComment by {comment.author} (Score: {comment.score}):"
            f"\n{comment.body}"
            f"\n---"
            for comment in post.comments[:comment_limit]
        )
        return "\n".join(lines)

    def _search_subreddit(self, reddit, subreddit, query, limit) -> str:
        """Blocking worker: search *subreddit* for *query* and format hits."""
        hits = reddit.subreddit(subreddit).search(query, limit=limit)
        return self._format_listing(hits, "No matching posts found.")

    def _get_post_details(self, reddit, post_id, comment_limit) -> str:
        """Blocking worker: fetch one submission and format it with comments."""
        post = reddit.submission(id=post_id)
        # limit=0 is passed through unchanged from the original call.
        post.comments.replace_more(limit=0)
        return self._format_post_details(post, comment_limit)

    def _get_subreddit_hot(self, reddit, subreddit, limit) -> str:
        """Blocking worker: fetch the hot listing of *subreddit* and format it."""
        hot_posts = reddit.subreddit(subreddit).hot(limit=limit)
        return self._format_listing(hot_posts, "No hot posts found.")

    def setup_tools(self):
        """Register the ``list_tools`` and ``call_tool`` handlers on the app."""

        @self.app.list_tools()
        async def list_tools() -> List[Tool]:
            """Describe the three Reddit tools and their input schemas."""
            return [
                Tool(
                    name="search_subreddit",
                    description="Search for posts in a specific subreddit",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "subreddit": {
                                "type": "string",
                                "description": "Name of the subreddit to search",
                            },
                            "query": {
                                "type": "string",
                                "description": "Search query",
                            },
                            "limit": {
                                "type": "integer",
                                "description": "Maximum number of results to return",
                                "default": 5,
                            },
                        },
                        "required": ["subreddit", "query"],
                    },
                ),
                Tool(
                    name="get_post_details",
                    description="Get detailed information about a specific Reddit post",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "post_id": {
                                "type": "string",
                                "description": "ID of the Reddit post",
                            },
                            "comment_limit": {
                                "type": "integer",
                                "description": "Maximum number of comments to fetch",
                                "default": 10,
                            },
                        },
                        "required": ["post_id"],
                    },
                ),
                Tool(
                    name="get_subreddit_hot",
                    description="Get hot posts from a specific subreddit",
                    inputSchema={
                        "type": "object",
                        "properties": {
                            "subreddit": {
                                "type": "string",
                                "description": "Name of the subreddit",
                            },
                            "limit": {
                                "type": "integer",
                                "description": "Maximum number of posts to return",
                                "default": 5,
                            },
                        },
                        "required": ["subreddit"],
                    },
                ),
            ]

        @self.app.call_tool()
        async def call_tool(name: str, arguments: dict) -> List[TextContent]:
            """Run the named tool and return its output as one text item.

            Failures of the Reddit calls themselves are reported as text
            ("Error ...: <message>"). Missing credentials (``ValueError``)
            and missing required arguments (``KeyError``) still propagate to
            the MCP framework, as they did before.
            """
            if name not in ("search_subreddit", "get_post_details", "get_subreddit_hot"):
                # Reject unknown tools before demanding Reddit credentials.
                return [TextContent(type="text", text=f"Unknown tool: {name}")]

            reddit = self.get_reddit_client()

            # Required arguments are read outside the try block on purpose so
            # that a missing key is not disguised as a Reddit API error.
            if name == "search_subreddit":
                error_prefix = "Error searching subreddit"
                worker = self._search_subreddit
                worker_args = (
                    arguments["subreddit"],
                    arguments["query"],
                    arguments.get("limit", 5),
                )
            elif name == "get_post_details":
                error_prefix = "Error getting post details"
                worker = self._get_post_details
                worker_args = (
                    arguments["post_id"],
                    arguments.get("comment_limit", 10),
                )
            else:  # get_subreddit_hot
                error_prefix = "Error getting hot posts"
                worker = self._get_subreddit_hot
                worker_args = (
                    arguments["subreddit"],
                    arguments.get("limit", 5),
                )

            try:
                # The PRAW calls are synchronous; run them in a worker thread
                # so they do not block the server's event loop. The client is
                # created per call and only touched by that one worker.
                text = await asyncio.to_thread(worker, reddit, *worker_args)
            except Exception as e:
                logger.exception("Tool %r failed", name)
                text = f"{error_prefix}: {str(e)}"
            return [TextContent(type="text", text=text)]

    async def run(self):
        """Serve MCP requests over stdio until the streams close.

        Any exception raised by the server loop is logged with its traceback
        and then re-raised.
        """
        logger.info("Starting Reddit MCP server...")
        async with stdio_server() as (read_stream, write_stream):
            try:
                await self.app.run(
                    read_stream,
                    write_stream,
                    self.app.create_initialization_options(),
                )
            except Exception as e:
                logger.error(f"Server error: {str(e)}", exc_info=True)
                raise
def main():
    """Build the Reddit MCP server and drive it on a fresh event loop."""
    asyncio.run(RedditMCPServer().run())
# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()