MCP Server Reddit

from enum import Enum import json from typing import Sequence import redditwarp.SYNC from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import Tool, TextContent, ImageContent, EmbeddedResource from mcp.shared.exceptions import McpError from pydantic import BaseModel class PostType(str, Enum): LINK = "link" TEXT = "text" GALLERY = "gallery" UNKNOWN = "unknown" class RedditTools(str, Enum): GET_FRONTPAGE_POSTS = "get_frontpage_posts" GET_SUBREDDIT_INFO = "get_subreddit_info" GET_SUBREDDIT_HOT_POSTS = "get_subreddit_hot_posts" GET_SUBREDDIT_NEW_POSTS = "get_subreddit_new_posts" GET_SUBREDDIT_TOP_POSTS = "get_subreddit_top_posts" GET_SUBREDDIT_RISING_POSTS = "get_subreddit_rising_posts" GET_POST_CONTENT = "get_post_content" GET_POST_COMMENTS = "get_post_comments" class SubredditInfo(BaseModel): name: str subscriber_count: int description: str | None class Post(BaseModel): id: str title: str author: str score: int subreddit: str url: str created_at: str comment_count: int post_type: PostType content: str | None class Comment(BaseModel): id: str author: str body: str score: int replies: list['Comment'] = [] class Moderator(BaseModel): name: str class PostDetail(BaseModel): post: Post comments: list[Comment] class RedditServer: def __init__(self): self.client = redditwarp.SYNC.Client() def _get_post_type(self, submission) -> PostType: """Helper method to determine post type""" if isinstance(submission, redditwarp.models.submission_SYNC.LinkPost): return PostType.LINK elif isinstance(submission, redditwarp.models.submission_SYNC.TextPost): return PostType.TEXT elif isinstance(submission, redditwarp.models.submission_SYNC.GalleryPost): return PostType.GALLERY return PostType.UNKNOWN # The type can actually be determined by submission.post_hint # - self for text # - image for image # - hosted:video for video def _get_post_content(self, submission) -> str | None: """Helper method to extract post content based on type""" if isinstance(submission, redditwarp.models.submission_SYNC.LinkPost): return submission.permalink elif isinstance(submission, redditwarp.models.submission_SYNC.TextPost): return submission.body elif isinstance(submission, redditwarp.models.submission_SYNC.GalleryPost): return str(submission.gallery_link) return None def _build_post(self, submission) -> Post: """Helper method to build Post object from submission""" return Post( id=submission.id36, title=submission.title, author=submission.author_display_name or '[deleted]', score=submission.score, subreddit=submission.subreddit.name, url=submission.permalink, created_at=submission.created_at.astimezone().isoformat(), comment_count=submission.comment_count, post_type=self._get_post_type(submission), content=self._get_post_content(submission) ) def get_frontpage_posts(self, limit: int = 10) -> list[Post]: """Get hot posts from Reddit frontpage""" posts = [] for subm in self.client.p.front.pull.hot(limit): posts.append(self._build_post(subm)) return posts def get_subreddit_info(self, subreddit_name: str) -> SubredditInfo: """Get information about a subreddit""" subr = self.client.p.subreddit.fetch_by_name(subreddit_name) return SubredditInfo( name=subr.name, subscriber_count=subr.subscriber_count, description=subr.public_description ) def _build_comment_tree(self, node, depth: int = 3) -> Comment | None: """Helper method to recursively build comment tree""" if depth <= 0 or not node: return None comment = node.value replies = [] for child in node.children: child_comment = self._build_comment_tree(child, depth - 1) if child_comment: replies.append(child_comment) return Comment( id=comment.id36, author=comment.author_display_name or '[deleted]', body=comment.body, score=comment.score, replies=replies ) def get_subreddit_hot_posts(self, subreddit_name: str, limit: int = 10) -> list[Post]: """Get hot posts from a specific subreddit""" posts = [] for subm in self.client.p.subreddit.pull.hot(subreddit_name, limit): posts.append(self._build_post(subm)) return posts def get_subreddit_new_posts(self, subreddit_name: str, limit: int = 10) -> list[Post]: """Get new posts from a specific subreddit""" posts = [] for subm in self.client.p.subreddit.pull.new(subreddit_name, limit): posts.append(self._build_post(subm)) return posts def get_subreddit_top_posts(self, subreddit_name: str, limit: int = 10, time: str = '') -> list[Post]: """Get top posts from a specific subreddit""" posts = [] for subm in self.client.p.subreddit.pull.top(subreddit_name, limit, time=time): posts.append(self._build_post(subm)) return posts def get_subreddit_rising_posts(self, subreddit_name: str, limit: int = 10) -> list[Post]: """Get rising posts from a specific subreddit""" posts = [] for subm in self.client.p.subreddit.pull.rising(subreddit_name, limit): posts.append(self._build_post(subm)) return posts def get_post_content(self, post_id: str, comment_limit: int = 10, comment_depth: int = 3) -> PostDetail: """Get detailed content of a specific post including comments""" submission = self.client.p.submission.fetch(post_id) post = self._build_post(submission) # Fetch comments comments = self.get_post_comments(post_id, comment_limit) return PostDetail(post=post, comments=comments) def get_post_comments(self, post_id: str, limit: int = 10) -> list[Comment]: """Get comments from a post""" comments = [] tree_node = self.client.p.comment_tree.fetch(post_id, sort='top', limit=limit) for node in tree_node.children: comment = self._build_comment_tree(node) if comment: comments.append(comment) return comments async def serve() -> None: server = Server("mcp-reddit") reddit_server = RedditServer() @server.list_tools() async def list_tools() -> list[Tool]: """List available Reddit tools.""" return [ Tool( name=RedditTools.GET_FRONTPAGE_POSTS.value, description="Get hot posts from Reddit frontpage", inputSchema={ "type": "object", "properties": { "limit": { "type": "integer", "description": "Number of posts to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 } } } ), Tool( name=RedditTools.GET_SUBREDDIT_INFO.value, description="Get information about a subreddit", inputSchema={ "type": "object", "properties": { "subreddit_name": { "type": "string", "description": "Name of the subreddit (e.g. 'Python', 'news')", } }, "required": ["subreddit_name"] } ), Tool( name=RedditTools.GET_SUBREDDIT_HOT_POSTS.value, description="Get hot posts from a specific subreddit", inputSchema={ "type": "object", "properties": { "subreddit_name": { "type": "string", "description": "Name of the subreddit (e.g. 'Python', 'news')", }, "limit": { "type": "integer", "description": "Number of posts to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 } }, "required": ["subreddit_name"] } ), Tool( name=RedditTools.GET_SUBREDDIT_NEW_POSTS.value, description="Get new posts from a specific subreddit", inputSchema={ "type": "object", "properties": { "subreddit_name": { "type": "string", "description": "Name of the subreddit (e.g. 'Python', 'news')", }, "limit": { "type": "integer", "description": "Number of posts to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 } }, "required": ["subreddit_name"] } ), Tool( name=RedditTools.GET_SUBREDDIT_TOP_POSTS.value, description="Get top posts from a specific subreddit", inputSchema={ "type": "object", "properties": { "subreddit_name": { "type": "string", "description": "Name of the subreddit (e.g. 'Python', 'news')", }, "limit": { "type": "integer", "description": "Number of posts to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 }, "time": { "type": "string", "description": "Time filter for top posts (e.g. 'hour', 'day', 'week', 'month', 'year', 'all')", "default": "", "enum": ["", "hour", "day", "week", "month", "year", "all"] } }, "required": ["subreddit_name"] } ), Tool( name=RedditTools.GET_SUBREDDIT_RISING_POSTS.value, description="Get rising posts from a specific subreddit", inputSchema={ "type": "object", "properties": { "subreddit_name": { "type": "string", "description": "Name of the subreddit (e.g. 'Python', 'news')", }, "limit": { "type": "integer", "description": "Number of posts to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 } }, "required": ["subreddit_name"] } ), Tool( name=RedditTools.GET_POST_CONTENT.value, description="Get detailed content of a specific post", inputSchema={ "type": "object", "properties": { "post_id": { "type": "string", "description": "ID of the post", }, "comment_limit": { "type": "integer", "description": "Number of top-level comments to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 }, "comment_depth": { "type": "integer", "description": "Maximum depth of comment tree (default: 3)", "default": 3, "minimum": 1, "maximum": 10 } }, "required": ["post_id"] } ), Tool( name=RedditTools.GET_POST_COMMENTS.value, description="Get comments from a post", inputSchema={ "type": "object", "properties": { "post_id": { "type": "string", "description": "ID of the post", }, "limit": { "type": "integer", "description": "Number of comments to return (default: 10)", "default": 10, "minimum": 1, "maximum": 100 } }, "required": ["post_id"] } ), ] @server.call_tool() async def call_tool( name: str, arguments: dict ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: """Handle tool calls for Reddit API.""" try: match name: case RedditTools.GET_FRONTPAGE_POSTS.value: limit = arguments.get("limit", 10) result = reddit_server.get_frontpage_posts(limit) case RedditTools.GET_SUBREDDIT_INFO.value: subreddit_name = arguments.get("subreddit_name") if not subreddit_name: raise ValueError("Missing required argument: subreddit_name") result = reddit_server.get_subreddit_info(subreddit_name) case RedditTools.GET_SUBREDDIT_HOT_POSTS.value: subreddit_name = arguments.get("subreddit_name") if not subreddit_name: raise ValueError("Missing required argument: subreddit_name") limit = arguments.get("limit", 10) result = reddit_server.get_subreddit_hot_posts(subreddit_name, limit) case RedditTools.GET_SUBREDDIT_NEW_POSTS.value: subreddit_name = arguments.get("subreddit_name") if not subreddit_name: raise ValueError("Missing required argument: subreddit_name") limit = arguments.get("limit", 10) result = reddit_server.get_subreddit_new_posts(subreddit_name, limit) case RedditTools.GET_SUBREDDIT_TOP_POSTS.value: subreddit_name = arguments.get("subreddit_name") if not subreddit_name: raise ValueError("Missing required argument: subreddit_name") limit = arguments.get("limit", 10) time = arguments.get("time", "") result = reddit_server.get_subreddit_top_posts(subreddit_name, limit, time) case RedditTools.GET_SUBREDDIT_RISING_POSTS.value: subreddit_name = arguments.get("subreddit_name") if not subreddit_name: raise ValueError("Missing required argument: subreddit_name") limit = arguments.get("limit", 10) result = reddit_server.get_subreddit_rising_posts(subreddit_name, limit) case RedditTools.GET_POST_CONTENT.value: post_id = arguments.get("post_id") if not post_id: raise ValueError("Missing required argument: post_id") comment_limit = arguments.get("comment_limit", 10) comment_depth = arguments.get("comment_depth", 3) result = reddit_server.get_post_content(post_id, comment_limit, comment_depth) case RedditTools.GET_POST_COMMENTS.value: post_id = arguments.get("post_id") if not post_id: raise ValueError("Missing required argument: post_id") limit = arguments.get("limit", 10) result = reddit_server.get_post_comments(post_id, limit) case _: raise ValueError(f"Unknown tool: {name}") return [ TextContent(type="text", text=json.dumps(result, default=lambda x: x.model_dump(), indent=2)) ] except Exception as e: raise ValueError(f"Error processing mcp-server-reddit query: {str(e)}") options = server.create_initialization_options() async with stdio_server() as (read_stream, write_stream): await server.run(read_stream, write_stream, options)