Telegram MCP Server
by sparfenyuk
- src
- mcp_telegram
from __future__ import annotations
import logging
import sys
import typing as t
from functools import singledispatch
from mcp.types import (
EmbeddedResource,
ImageContent,
TextContent,
Tool,
)
from pydantic import BaseModel, ConfigDict
from telethon import TelegramClient, custom, functions, types # type: ignore[import-untyped]
from .telegram import create_client
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# How to add a new tool:
#
# 1. Create a new class that inherits from ToolArgs
# ```python
# class NewTool(ToolArgs):
#     """Description of the new tool."""
#     pass
# ```
# Attributes of the class will be used as arguments for the tool.
# The class docstring will be used as the tool description.
#
# 2. Implement the tool_runner function for the new class
# ```python
# @tool_runner.register
# async def new_tool(args: NewTool) -> t.Sequence[TextContent | ImageContent | EmbeddedResource]:
#     pass
# ```
# The function should return a sequence of TextContent, ImageContent or EmbeddedResource.
# The function should be async and accept a single argument of the new class.
#
# 3. Done! Restart the client and the new tool should be available.
# Base class for every tool's argument model. tool_description() turns a
# subclass's class name, docstring and pydantic JSON schema into an MCP Tool,
# and tool_runner dispatches on the concrete subclass type.
# NOTE: documented with comments only — a class docstring here would become
# pydantic schema data rather than pure documentation.
class ToolArgs(BaseModel):
    # Empty ConfigDict: pydantic's default model configuration applies.
    model_config = ConfigDict()
@singledispatch
async def tool_runner(
    args,  # noqa: ANN001
) -> t.Sequence[TextContent | ImageContent | EmbeddedResource]:
    """Run the tool whose argument model matches ``type(args)``.

    This is the fallback implementation of a ``functools.singledispatch``
    function: each tool registers an overload with ``@tool_runner.register``,
    keyed on the annotated type of its single argument. Any argument type
    without a registered overload ends up here.

    Raises:
        NotImplementedError: Always, naming the unsupported argument type.
    """
    unsupported = type(args)
    raise NotImplementedError(f"Unsupported type: {unsupported}")
def tool_description(args: type[ToolArgs]) -> Tool:
    """Describe a ``ToolArgs`` subclass as an MCP ``Tool``.

    The class name becomes the tool name, the class docstring becomes the tool
    description, and the pydantic JSON schema of the model's fields becomes
    the tool's input schema.
    """
    fields = {
        "name": args.__name__,
        "description": args.__doc__,
        "inputSchema": args.model_json_schema(),
    }
    return Tool(**fields)
def tool_args(tool: Tool, *args, **kwargs) -> ToolArgs:  # noqa: ANN002, ANN003
    """Instantiate the argument model for ``tool``.

    The model class is looked up by name: ``tool.name`` must be the name of a
    ``ToolArgs`` subclass defined in this module. Positional and keyword
    arguments are passed straight to the model constructor, so pydantic
    validation errors propagate to the caller.

    Raises:
        KeyError: If ``tool.name`` does not name a ``ToolArgs`` subclass in
            this module.
    """
    candidate = sys.modules[__name__].__dict__.get(tool.name)
    # The module namespace also holds tool runners, helpers and imported
    # classes (e.g. create_client, TelegramClient). Tool names may ultimately
    # come from an MCP client request, so never call anything that is not an
    # argument model. KeyError matches what a missing name raised before.
    is_tool_model = (
        isinstance(candidate, type)
        and issubclass(candidate, ToolArgs)
        and candidate is not ToolArgs
    )
    if not is_tool_model:
        raise KeyError(tool.name)
    return candidate(*args, **kwargs)
### ListDialogs ###
# Arguments for the ListDialogs tool. NOTE: the class docstring is sent to MCP
# clients as the tool description (tool_description reads __doc__), so it is
# runtime data, not just documentation.
class ListDialogs(ToolArgs):
    """List available dialogs, chats and channels."""
    # When True, dialogs whose unread_count is 0 are omitted from the result.
    unread: bool = False
    # Forwarded unchanged as `archived=` to TelegramClient.iter_dialogs.
    archived: bool = False
    # Forwarded unchanged as `ignore_pinned=` to TelegramClient.iter_dialogs.
    ignore_pinned: bool = False
@tool_runner.register
async def list_dialogs(
    args: ListDialogs,
) -> t.Sequence[TextContent | ImageContent | EmbeddedResource]:
    """Summarize the account's dialogs, one ``TextContent`` line per dialog.

    Each line has the form
    ``name='<name>' id=<id> unread=<unread_count> mentions=<unread_mentions_count>``.
    ``args.archived`` and ``args.ignore_pinned`` are forwarded to
    ``client.iter_dialogs``; with ``args.unread`` set, dialogs that have no
    unread messages are left out.
    """
    client: TelegramClient
    logger.info("method[ListDialogs] args[%s]", args)

    async with create_client() as client:
        dialogs = client.iter_dialogs(archived=args.archived, ignore_pinned=args.ignore_pinned)
        summaries: list[TextContent] = [
            TextContent(
                type="text",
                text=(
                    f"name='{dialog.name}' id={dialog.id} "
                    f"unread={dialog.unread_count} mentions={dialog.unread_mentions_count}"
                ),
            )
            async for dialog in dialogs
            if not (args.unread and dialog.unread_count == 0)
        ]
    return summaries
### ListMessages ###
# Arguments for the ListMessages tool. NOTE: the class docstring is sent to MCP
# clients as the tool description (tool_description reads __doc__), so it must
# describe exactly what list_messages does.
class ListMessages(ToolArgs):
    """
    List messages in a given dialog, chat or channel. The messages are listed in order from newest to oldest.

    If `unread` is set to `True`, only the newest N messages are listed, where N is the dialog's unread
    message count (capped by `limit`). This tool does not mark messages as read: they stop counting as
    unread only once they are read in Telegram.

    At most `limit` messages are fetched (default 100). Only messages that contain text are returned, so
    the result may contain fewer entries than the limit.
    """

    # Id of the dialog, chat or channel to read from.
    dialog_id: int
    # When True, fetch at most the dialog's unread_count messages.
    unread: bool = False
    # Upper bound on the number of messages fetched.
    limit: int = 100
@tool_runner.register
async def list_messages(
    args: ListMessages,
) -> t.Sequence[TextContent | ImageContent | EmbeddedResource]:
    """Return the text of recent messages in ``args.dialog_id``, newest first.

    The dialog is first fetched with ``messages.GetPeerDialogsRequest`` so its
    ``unread_count`` is available. When ``args.unread`` is set, at most
    ``min(unread_count, args.limit)`` messages are fetched; otherwise at most
    ``args.limit``. Messages with empty text are skipped, so fewer entries than
    were fetched may be returned. Messages are not marked as read.

    Args:
        args: Validated ``ListMessages`` arguments.

    Returns:
        One ``TextContent`` per fetched message that has text, newest first.

    Raises:
        ValueError: If Telegram returns no result, or no dialog when
            ``args.unread`` needs the dialog's unread count.
        TypeError: If Telegram answers with an unexpected response type.
    """
    client: TelegramClient
    logger.info("method[ListMessages] args[%s]", args)

    response: list[TextContent] = []
    async with create_client() as client:
        result = await client(functions.messages.GetPeerDialogsRequest(peers=[args.dialog_id]))
        if not result:
            raise ValueError(f"Channel not found: {args.dialog_id}")
        if not isinstance(result, types.messages.PeerDialogs):
            raise TypeError(f"Unexpected result: {type(result)}")

        for peer_dialog in result.dialogs:
            logger.debug("dialog: %s", peer_dialog)
        for peer_message in result.messages:
            logger.debug("message: %s", peer_message)

        if args.unread:
            # The unread cap needs the dialog's unread_count. Fail explicitly
            # rather than reading an unbound/leaked loop variable when
            # Telegram returned no dialog for this peer.
            if not result.dialogs:
                raise ValueError(f"Channel not found: {args.dialog_id}")
            # A single peer was requested, so the first dialog corresponds to it.
            limit = min(result.dialogs[0].unread_count, args.limit)
        else:
            limit = args.limit

        iter_messages_args: dict[str, t.Any] = {
            "entity": args.dialog_id,
            "reverse": False,  # newest first
            "limit": limit,
        }
        logger.debug("iter_messages_args: %s", iter_messages_args)
        async for message in client.iter_messages(**iter_messages_args):
            logger.debug("message: %s", type(message))
            # Only Message objects with non-empty text are reported; e.g. media
            # without a caption yields empty text and is skipped.
            if isinstance(message, custom.Message) and message.text:
                logger.debug("message: %s", message.text)
                response.append(TextContent(type="text", text=message.text))
    return response