# app.py • 16 kB
import os
import asyncio
import discord
from discord.ext import commands
from fastapi import FastAPI
from fastmcp import FastMCP
from fastmcp.server.auth import BearerAuthProvider
from fastmcp.server.auth.providers.bearer import RSAKeyPair
from fastmcp.server.dependencies import get_access_token
import uvicorn
from dotenv import load_dotenv
import logging
from datetime import datetime
import jwt
# Load environment variables from a local .env file (python-dotenv).
load_dotenv()
# --------------------------------
# Setup Audit Logging
# --------------------------------
# log_audit() writes its records with logging.info(), i.e. through the root
# logger. Library loggers (e.g. discord.py's) propagate to the root logger by
# default, so a plain basicConfig(filename="audit.log") would interleave their
# INFO chatter with the audit trail. The file handler therefore accepts only
# records emitted on the root logger itself; warnings and errors from every
# other logger are sent to stderr instead so they are not silently lost.
_audit_file_handler = logging.FileHandler("audit.log", delay=True)
_audit_file_handler.addFilter(lambda record: record.name == "root")
_library_log_handler = logging.StreamHandler()
_library_log_handler.setLevel(logging.WARNING)
_library_log_handler.setFormatter(
    logging.Formatter("%(levelname)s %(name)s: %(message)s")
)
_library_log_handler.addFilter(lambda record: record.name != "root")
# basicConfig still applies the bare-message format to the audit handler (it
# has no formatter of its own) and stays a no-op if logging is already set up.
logging.basicConfig(
    handlers=[_audit_file_handler, _library_log_handler],
    level=logging.INFO,
    format="%(message)s",
)
def log_audit(tool: str, user_id: str, args: dict, success: bool, error: str = None):
entry = {
"timestamp": datetime.utcnow().isoformat(),
"tool": tool,
"user_id": user_id,
"args": args,
"success": success
}
if error:
entry["error"] = error
logging.info(entry)
# Helper function to create a temporary, REST-only bot client
async def create_temp_bot(token: str) -> commands.Bot:
    """Create a logged-in bot client for a single tool call.

    ``Bot.start()`` is ``login()`` followed by ``connect()``, and ``connect()``
    runs the gateway event loop until the client is closed. Awaiting it here
    would therefore never return and every tool would hang. Only ``login()``
    is awaited instead. That authenticates the HTTP client, which is all the
    REST calls made by the tools need (fetch_channel, send, history,
    fetch_message, delete).

    No gateway connection is opened, so the client's cache stays empty. As a
    result ``bot.get_channel`` finds nothing and callers fall back to
    ``bot.fetch_channel``.

    Args:
        token: Discord bot token.

    Returns:
        A logged-in bot. The caller must release it with ``cleanup_bot``.

    Raises:
        discord.LoginFailure: if Discord rejects the token. Any other error
            raised by ``login()`` is re-raised as well.
    """
    intents = discord.Intents.default()
    intents.message_content = True
    bot = commands.Bot(command_prefix="!", intents=intents)
    try:
        await bot.login(token)
    except BaseException:
        # The caller never receives the bot on failure, so release its HTTP
        # session here before propagating the error.
        await bot.close()
        raise
    return bot
async def cleanup_bot(bot: commands.Bot):
    """Shut down a bot returned by create_temp_bot.

    Does nothing when ``bot`` is falsy (e.g. None) or has already been closed.
    """
    if not bot or bot.is_closed():
        return
    await bot.close()
def get_discord_token() -> str:
    """Return the Discord bot token carried in the caller's access token (JWT).

    The ``discord_token`` claim is read from ``additional_claims`` when the
    access-token object exposes it. Otherwise the raw JWT is decoded to read
    the claim. Signature verification is skipped on that decode because the
    server is configured with a BearerAuthProvider that validates bearer
    tokens before tools run. This function only extracts a claim.

    Raises:
        ValueError: if there is no access token, or it carries no
            ``discord_token`` claim.
    """
    access_token = get_access_token()
    if access_token is None:
        raise ValueError("No Discord token found in authentication")

    extra_claims = getattr(access_token, "additional_claims", None)
    if extra_claims:
        discord_token = extra_claims.get("discord_token")
        if discord_token:
            return discord_token

    raw_jwt = getattr(access_token, "token", None)
    if raw_jwt:
        try:
            decoded = jwt.decode(raw_jwt, options={"verify_signature": False})
        except jwt.PyJWTError:
            # Malformed token: treat it like a token without the claim.
            decoded = {}
        discord_token = decoded.get("discord_token")
        if discord_token:
            return discord_token

    raise ValueError("No Discord token found in authentication")
# --------------------------------
# Setup Authentication
# --------------------------------
# Generate RSA key pair for development (use external IdP for production).
# NOTE: the key pair is created at import time and exists only in this
# process's memory. Every restart, and every other process that imports this
# module, gets a different key, so tokens minted elsewhere will not verify.
key_pair = RSAKeyPair.generate()
# Bearer-token auth: requests must present a JWT signed by key_pair with this
# issuer and audience (the values generate_access_token signs with).
auth = BearerAuthProvider(
    public_key=key_pair.public_key,
    issuer="https://discord-mcp-server.local",
    audience="discord-mcp-server",
)
# --------------------------------
# Set up FastAPI + FastMCP with Authentication
# --------------------------------
mcp = FastMCP(name="Discord MCP Server", auth=auth)
# Expose the MCP server over HTTP by mounting its ASGI app into FastAPI.
# Without the mount, the FastAPI app served by uvicorn has no MCP endpoint.
# Tools registered on `mcp` after this point are still served.
# FastMCP's lifespan must be passed to FastAPI: lifespans of mounted sub-apps
# are not run, and the MCP session manager is started inside it.
mcp_app = mcp.http_app(path="/mcp")
app = FastAPI(lifespan=mcp_app.lifespan)
# Mounted at the root, so the endpoint is http://<host>:<port>/mcp. This mount
# matches every path, so any extra FastAPI routes must be added before it.
app.mount("/", mcp_app)
# --------------------------------
# Token Generation Utility
# --------------------------------
def generate_access_token(
    discord_token: str,
    user_id: str = "discord-user",
    expires_in_seconds: int = 3600,
) -> str:
    """Generate a signed JWT access token that carries a Discord bot token.

    The token is signed with this module's in-memory ``key_pair``. It only
    verifies against a server running in the *same* process. A token minted
    from a separate ``python -c ...`` process uses a different key.

    NOTE(security): JWT claims are signed, not encrypted. Anyone holding the
    returned token can decode it and read ``discord_token`` in plain text.

    Args:
        discord_token: Discord bot token to embed in the ``discord_token`` claim.
        user_id: JWT subject identifying the caller.
        expires_in_seconds: Token lifetime (default: 3600, i.e. 1 hour).

    Returns:
        The encoded JWT.

    Raises:
        ValueError: if ``discord_token`` is empty. Such a token could never
            be used by the tools.
    """
    if not discord_token:
        raise ValueError("discord_token must be a non-empty string")
    return key_pair.create_token(
        subject=user_id,
        issuer="https://discord-mcp-server.local",
        audience="discord-mcp-server",
        additional_claims={
            "discord_token": discord_token,
            "permissions": ["discord:read", "discord:write", "discord:moderate"],
        },
        expires_in_seconds=expires_in_seconds,
    )
# --------------------------------
# MCP Tool: send_message
# --------------------------------
@mcp.tool
async def send_message(channel_id: str, message: str) -> dict:
    """
    Send a message to a specific Discord channel.
    Args:
        channel_id (str): The Discord channel ID to send the message to
        message (str): The message content to send
    Returns:
        dict: On success, status, message_id, channel_id (IDs as strings) and
        timestamp. On failure, {"error": <message>}. Every outcome is
        recorded with log_audit.
    """
    # Bound before the try so the except handler can always audit, even when
    # reading the access token itself fails.
    tool_name = "send_message"
    args = {"channel_id": channel_id, "message": message}
    user_id = "unknown"
    bot = None
    try:
        access_token = get_access_token()
        user_id = getattr(access_token, "client_id", None) or "unknown"
        discord_token = get_discord_token()
        # Parse before logging in so a malformed ID costs no Discord round-trip.
        cid = int(channel_id)
        bot = await create_temp_bot(discord_token)
        channel = bot.get_channel(cid) or await bot.fetch_channel(cid)
        sent = await channel.send(message)
        result = {
            "status": "success",
            # Snowflakes as strings, like the other tools, so JSON clients
            # cannot lose precision on 64-bit IDs.
            "message_id": str(sent.id),
            "channel_id": str(sent.channel.id),
            "timestamp": sent.created_at.isoformat(),
        }
        log_audit(tool_name, user_id, args, success=True)
        return result
    except Exception as e:
        # Includes ValueError (missing Discord token, non-numeric channel_id),
        # which previously returned without being audited.
        log_audit(tool_name, user_id, args, success=False, error=str(e))
        return {"error": str(e)}
    finally:
        if bot:
            await cleanup_bot(bot)
@mcp.tool
async def get_messages(channel_id: str, limit: int = 10) -> dict:
    """
    Fetches the last `limit` messages from a Discord channel.
    Args:
        channel_id (str): The channel to fetch messages from.
        limit (int): Number of messages to fetch (default: 10, must be >= 1).
    Returns:
        dict: {"messages": [...]} with id, author, content and timestamp per
        message, or {"error": <message>}. Every outcome is recorded with
        log_audit.
    """
    tool_name = "get_messages"
    args = {"channel_id": channel_id, "limit": limit}
    user_id = "unknown"
    bot = None

    def fail(error: str) -> dict:
        """Audit a failed call and build the error payload for the client."""
        log_audit(tool_name, user_id, args, success=False, error=error)
        return {"error": error}

    try:
        access_token = get_access_token()
        user_id = getattr(access_token, "client_id", None) or "unknown"
        if limit < 1:
            return fail("limit must be a positive integer.")
        discord_token = get_discord_token()
        cid = int(channel_id)
        bot = await create_temp_bot(discord_token)
        channel = bot.get_channel(cid) or await bot.fetch_channel(cid)
        # Check if it's a text channel
        if not isinstance(channel, (discord.TextChannel, discord.Thread)):
            return fail("Invalid channel type. Must be a text or thread channel.")
        # history() is an async iterator; .flatten() was removed in discord.py 2.0.
        messages = [msg async for msg in channel.history(limit=limit)]
        result = {
            "messages": [
                {
                    "id": str(msg.id),
                    "author": str(msg.author),
                    "content": msg.content,
                    "timestamp": msg.created_at.isoformat(),
                }
                for msg in messages
            ]
        }
        log_audit(tool_name, user_id, args, success=True)
        return result
    except discord.Forbidden:
        return fail("Missing permission to read messages in that channel.")
    except discord.HTTPException as e:
        return fail(f"Discord API error: {str(e)}")
    except Exception as e:
        return fail(f"Unexpected error: {str(e)}")
    finally:
        if bot:
            await cleanup_bot(bot)
@mcp.tool
async def get_channel_info(channel_id: str) -> dict:
    """
    Retrieves metadata about a given Discord channel.
    Args:
        channel_id (str): The channel ID to query.
    Returns:
        dict: id, name, type, guild_id, guild_name, position and created_at,
        plus topic, nsfw and slowmode_delay for text channels. Attributes a
        channel type does not have (e.g. a thread's position) are None.
        On failure, {"error": <message>}. Every outcome is recorded with
        log_audit.
    """
    tool_name = "get_channel_info"
    args = {"channel_id": channel_id}
    user_id = "unknown"
    bot = None

    def fail(error: str) -> dict:
        """Audit a failed call and build the error payload for the client."""
        log_audit(tool_name, user_id, args, success=False, error=error)
        return {"error": error}

    try:
        access_token = get_access_token()
        user_id = getattr(access_token, "client_id", None) or "unknown"
        discord_token = get_discord_token()
        cid = int(channel_id)
        bot = await create_temp_bot(discord_token)
        channel = bot.get_channel(cid) or await bot.fetch_channel(cid)
        channel_type = (
            "text" if isinstance(channel, discord.TextChannel)
            else "voice" if isinstance(channel, discord.VoiceChannel)
            else "stage" if isinstance(channel, discord.StageChannel)
            else "forum" if isinstance(channel, discord.ForumChannel)
            else "thread" if isinstance(channel, discord.Thread)
            else "category" if isinstance(channel, discord.CategoryChannel)
            else "unknown"
        )
        # DM channels have no guild, so every guild access is guarded.
        guild = getattr(channel, "guild", None)
        guild_name = getattr(guild, "name", None)
        if guild is not None and not guild_name:
            # With an empty client cache the guild may be a nameless
            # placeholder; resolve the name over REST (best effort).
            try:
                guild_name = (await bot.fetch_guild(guild.id)).name
            except discord.HTTPException:
                guild_name = None
        # Thread.created_at can be None, and threads have no `position`.
        created_at = getattr(channel, "created_at", None)
        info = {
            "id": str(channel.id),
            "name": getattr(channel, "name", None),
            "type": channel_type,
            "guild_id": str(guild.id) if guild is not None else None,
            "guild_name": guild_name,
            "position": getattr(channel, "position", None),
            "created_at": created_at.isoformat() if created_at else None,
        }
        # Optional fields (text-only)
        if isinstance(channel, discord.TextChannel):
            info.update({
                "topic": channel.topic,
                "nsfw": channel.is_nsfw(),
                "slowmode_delay": channel.slowmode_delay,
            })
        log_audit(tool_name, user_id, args, success=True)
        return info
    except discord.NotFound:
        return fail("Channel not found.")
    except discord.Forbidden:
        return fail("Bot doesn't have permission to access this channel.")
    except discord.HTTPException as e:
        return fail(f"Discord API error: {str(e)}")
    except Exception as e:
        return fail(f"Unexpected error: {str(e)}")
    finally:
        if bot:
            await cleanup_bot(bot)
@mcp.tool
async def search_messages(channel_id: str, query: str, limit: int = 20, scan_limit: int = 100) -> dict:
    """
    Searches recent messages in a Discord channel containing a keyword.
    Args:
        channel_id (str): ID of the Discord channel.
        query (str): Keyword to search for in message content (case-insensitive).
        limit (int): Max messages to return (default: 20, must be >= 1).
        scan_limit (int): How many recent messages to scan (default: 100,
            must be >= 1).
    Returns:
        dict: {"matches": [...]} with at most `limit` entries, in the order
        channel.history() yields them, or {"error": <message>}. Every outcome
        is recorded with log_audit.
    """
    tool_name = "search_messages"
    args = {"channel_id": channel_id, "query": query, "limit": limit, "scan_limit": scan_limit}
    user_id = "unknown"
    bot = None

    def fail(error: str) -> dict:
        """Audit a failed call and build the error payload for the client."""
        log_audit(tool_name, user_id, args, success=False, error=error)
        return {"error": error}

    try:
        access_token = get_access_token()
        user_id = getattr(access_token, "client_id", None) or "unknown"
        if limit < 1 or scan_limit < 1:
            return fail("limit and scan_limit must be positive integers.")
        discord_token = get_discord_token()
        cid = int(channel_id)
        bot = await create_temp_bot(discord_token)
        channel = bot.get_channel(cid) or await bot.fetch_channel(cid)
        if not isinstance(channel, (discord.TextChannel, discord.Thread)):
            return fail("Invalid channel type. Must be a text or thread channel.")
        # history() is an async iterator; .flatten() was removed in discord.py 2.0.
        history = [msg async for msg in channel.history(limit=scan_limit)]
        needle = query.casefold()
        matches = [
            {
                "id": str(msg.id),
                "author": str(msg.author),
                "content": msg.content,
                "timestamp": msg.created_at.isoformat(),
            }
            for msg in history
            if needle in msg.content.casefold()
        ]
        log_audit(tool_name, user_id, args, success=True)
        return {"matches": matches[:limit]}
    except discord.Forbidden:
        return fail("Bot lacks permission to read messages in this channel.")
    except discord.HTTPException as e:
        return fail(f"Discord API error: {str(e)}")
    except Exception as e:
        return fail(f"Unexpected error: {str(e)}")
    finally:
        if bot:
            await cleanup_bot(bot)
@mcp.tool
async def moderate_content(channel_id: str, message_ids: list[str]) -> dict:
    """
    Deletes specific messages in a channel.
    Args:
        channel_id (str): The ID of the channel.
        message_ids (list[str]): List of message IDs to delete.
    Returns:
        dict: {"status": "success" | "partial", "deleted": [...], "failed":
        [{"id", "error"}, ...]}. Status is "partial" whenever at least one
        deletion failed. On a channel-level failure, {"error": <message>}.
        The audit entry is marked successful only if nothing failed.
    """
    tool_name = "moderate_content"
    args = {"channel_id": channel_id, "message_ids": message_ids}
    user_id = "unknown"
    bot = None

    def fail(error: str) -> dict:
        """Audit a failed call and build the error payload for the client."""
        log_audit(tool_name, user_id, args, success=False, error=error)
        return {"error": error}

    try:
        access_token = get_access_token()
        user_id = getattr(access_token, "client_id", None) or "unknown"
        discord_token = get_discord_token()
        cid = int(channel_id)
        bot = await create_temp_bot(discord_token)
        channel = bot.get_channel(cid) or await bot.fetch_channel(cid)
        # Categories, forums and similar channels have no messages to fetch.
        if not isinstance(channel, discord.abc.Messageable):
            return fail("Invalid channel type. Must be a channel that contains messages.")
        deleted = []
        failed = []
        for mid in message_ids:
            try:
                msg = await channel.fetch_message(int(mid))
                await msg.delete()
                deleted.append(mid)
            except ValueError:
                # One malformed ID must not abort the rest of the batch.
                failed.append({"id": mid, "error": "Invalid message ID"})
            except discord.NotFound:
                failed.append({"id": mid, "error": "Not found"})
            except discord.Forbidden:
                failed.append({"id": mid, "error": "Permission denied"})
            except discord.HTTPException as e:
                failed.append({"id": mid, "error": str(e)})
        if failed:
            log_audit(
                tool_name, user_id, args, success=False,
                error=f"{len(failed)} of {len(message_ids)} deletions failed",
            )
        else:
            log_audit(tool_name, user_id, args, success=True)
        return {
            "status": "partial" if failed else "success",
            "deleted": deleted,
            "failed": failed,
        }
    except discord.Forbidden:
        return fail("Bot lacks permission to manage messages in this channel.")
    except discord.HTTPException as e:
        return fail(f"Discord API error: {str(e)}")
    except Exception as e:
        return fail(f"Unexpected error: {str(e)}")
    finally:
        if bot:
            await cleanup_bot(bot)
# --------------------------------
# Start MCP Server
# --------------------------------
def start(host: str = "0.0.0.0", port: int = 8000):
    """Print a startup banner, then serve ``app`` with uvicorn (blocks until shutdown).

    Args:
        host: Interface to bind (default "0.0.0.0", i.e. all interfaces).
        port: TCP port to listen on (default 8000).

    Access tokens are signed with ``key_pair``, which is generated in memory
    when this module is imported. Only tokens minted by *this* process verify.
    A token generated in a separate process, e.g. via
    ``python -c "from app import generate_access_token; ..."``, is signed with
    a different key and will be rejected. So when ``DISCORD_TOKEN`` is set in
    the environment, a valid token is minted and printed here before the
    server starts.
    """
    print("Starting Authenticated Discord MCP Server...")
    print("Authentication: Bearer Token (JWT)")
    print("Available tools:")
    print(" - send_message: Send messages to Discord channels")
    print(" - get_messages: Fetch recent messages from channels")
    print(" - get_channel_info: Get channel metadata")
    print(" - search_messages: Search for messages containing keywords")
    print(" - moderate_content: Delete specific messages")
    print(f"Server will run on http://localhost:{port}")
    discord_token = os.getenv("DISCORD_TOKEN")
    if discord_token:
        print("\nAccess token for this server process (expires in 1 hour):")
        print(f" {generate_access_token(discord_token)}")
    else:
        print("\nTo get an access token, set DISCORD_TOKEN (e.g. in .env) and restart;")
        print(" a token is then printed here at startup. Tokens generated in another")
        print(" process are signed with a different key and will be rejected.")
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    start()