import os
from pathlib import Path
import pytest
from mcp import types
from mcp.shared.exceptions import McpError
from mcp.types import ErrorData
# Seed placeholder Telegram credentials before `main` is imported below.
# `setdefault` leaves any value already present in the environment untouched.
# NOTE(review): presumably `main` reads these variables at import time — confirm.
os.environ.setdefault("TELEGRAM_API_ID", "12345")
os.environ.setdefault("TELEGRAM_API_HASH", "dummy_hash")
import main
class _DummySession:
    """Session stub that answers ``list_roots`` with a fixed list of roots."""

    def __init__(self, roots):
        # Stored as-is; handed back verbatim inside the ListRootsResult.
        self._roots = roots

    async def list_roots(self):
        """Return the configured roots wrapped in ``types.ListRootsResult``."""
        result = types.ListRootsResult(roots=self._roots)
        return result
class _DummyContext:
    """Context stub exposing a ``session`` backed by ``_DummySession``."""

    def __init__(self, roots):
        stub_session = _DummySession(roots)
        self.session = stub_session
class _FailingSession:
    """Session stub whose ``list_roots`` always raises the configured error."""

    def __init__(self, error):
        # The exception instance to raise on every list_roots() call.
        self._error = error

    async def list_roots(self):
        """Raise the exception supplied at construction time."""
        failure = self._error
        raise failure
class _FailingContext:
    """Context stub exposing a ``session`` backed by ``_FailingSession``."""

    def __init__(self, error):
        broken_session = _FailingSession(error)
        self.session = broken_session
class _MissingRootsSession:
    """Session stub that deliberately defines no ``list_roots`` method."""
class _MissingRootsContext:
    """Context stub exposing a ``session`` backed by ``_MissingRootsSession``."""

    def __init__(self):
        rootless_session = _MissingRootsSession()
        self.session = rootless_session
@pytest.mark.asyncio
async def test_readable_relative_path_resolves_inside_first_server_root(
    tmp_path, monkeypatch
):
    """With one server root and no context, a relative name resolves inside that root."""
    allowed_dir = (tmp_path / "root").resolve()
    allowed_dir.mkdir(parents=True)
    expected_file = allowed_dir / "document.txt"
    expected_file.write_text("ok", encoding="utf-8")
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [allowed_dir])

    outcome = await main._resolve_readable_file_path(
        raw_path="document.txt",
        ctx=None,
        tool_name="send_file",
    )
    resolved_path, failure = outcome

    assert failure is None
    assert resolved_path == expected_file.resolve()
@pytest.mark.asyncio
async def test_readable_path_rejects_traversal(tmp_path, monkeypatch):
    """A relative path containing ``..`` is refused with the traversal error message."""
    allowed_dir = (tmp_path / "root").resolve()
    allowed_dir.mkdir(parents=True)
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [allowed_dir])

    outcome = await main._resolve_readable_file_path(
        raw_path="../etc/passwd",
        ctx=None,
        tool_name="send_file",
    )
    resolved_path, failure = outcome

    assert resolved_path is None
    assert failure == "Path traversal is not allowed."
@pytest.mark.asyncio
async def test_readable_path_rejects_outside_root(tmp_path, monkeypatch):
    """An absolute path to an existing file outside the allowed root is refused."""
    allowed_dir = (tmp_path / "root").resolve()
    foreign_dir = (tmp_path / "outside").resolve()
    for directory in (allowed_dir, foreign_dir):
        directory.mkdir(parents=True)

    # The file really exists, so the rejection is about location, not absence.
    foreign_file = foreign_dir / "outside.txt"
    foreign_file.write_text("no", encoding="utf-8")
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [allowed_dir])

    resolved_path, failure = await main._resolve_readable_file_path(
        raw_path=str(foreign_file),
        ctx=None,
        tool_name="send_file",
    )

    assert resolved_path is None
    assert failure == "Path is outside allowed roots."
@pytest.mark.asyncio
async def test_client_roots_replace_server_allowlist(tmp_path, monkeypatch):
    """Roots reported by the client session are used instead of the server allowlist."""
    server_dir = (tmp_path / "server_root").resolve()
    client_dir = (tmp_path / "client_root").resolve()
    server_dir.mkdir(parents=True)
    client_dir.mkdir(parents=True)

    (server_dir / "server.txt").write_text("server", encoding="utf-8")
    advertised_file = client_dir / "client.txt"
    advertised_file.write_text("client", encoding="utf-8")

    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_dir])
    client_root = types.Root(uri=client_dir.as_uri())
    client_ctx = _DummyContext([client_root])

    effective_roots = await main._get_effective_allowed_roots(client_ctx)
    assert effective_roots == [client_dir]

    resolved_path, failure = await main._resolve_readable_file_path(
        raw_path="client.txt",
        ctx=client_ctx,
        tool_name="send_file",
    )
    assert failure is None
    assert resolved_path == advertised_file.resolve()
@pytest.mark.asyncio
async def test_empty_client_roots_disable_file_tools(tmp_path, monkeypatch):
    """An empty roots list from the client yields no effective roots and a deny-all error."""
    server_dir = (tmp_path / "server_root").resolve()
    server_dir.mkdir(parents=True)
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_dir])
    empty_ctx = _DummyContext([])

    effective_roots = await main._get_effective_allowed_roots(empty_ctx)
    assert effective_roots == []

    resolved_path, failure = await main._resolve_readable_file_path(
        raw_path="server.txt",
        ctx=empty_ctx,
        tool_name="send_file",
    )
    assert resolved_path is None
    assert failure is not None
    # Both fragments must appear in the user-facing explanation.
    for fragment in ("empty MCP Roots list", "deny-all"):
        assert fragment in failure
@pytest.mark.asyncio
async def test_mcp_method_not_found_falls_back_to_server_allowlist(
    tmp_path, monkeypatch
):
    """An ``McpError`` with code -32601 from ``list_roots`` keeps the server allowlist."""
    server_dir = (tmp_path / "server_root").resolve()
    server_dir.mkdir(parents=True)
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_dir])

    not_found = ErrorData(code=-32601, message="Method not found")
    failing_ctx = _FailingContext(McpError(not_found))

    effective_roots = await main._get_effective_allowed_roots(failing_ctx)
    assert effective_roots == [server_dir]
@pytest.mark.asyncio
async def test_missing_list_roots_method_falls_back_to_server_allowlist(
    tmp_path, monkeypatch
):
    """A session object without ``list_roots`` keeps the server allowlist in effect."""
    server_dir = (tmp_path / "server_root").resolve()
    server_dir.mkdir(parents=True)
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_dir])

    rootless_ctx = _MissingRootsContext()
    effective_roots = await main._get_effective_allowed_roots(rootless_ctx)

    assert effective_roots == [server_dir]
@pytest.mark.asyncio
async def test_unexpected_roots_error_disables_file_path_tools(tmp_path, monkeypatch):
    """A non-MCP error from ``list_roots`` yields no roots and a "disabled" error."""
    server_dir = (tmp_path / "server_root").resolve()
    server_dir.mkdir(parents=True)
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [server_dir])

    transport_error = RuntimeError("transport failure")
    failing_ctx = _FailingContext(transport_error)

    effective_roots = await main._get_effective_allowed_roots(failing_ctx)
    assert effective_roots == []

    resolved_path, failure = await main._resolve_readable_file_path(
        raw_path="anything.txt",
        ctx=failing_ctx,
        tool_name="send_file",
    )
    assert resolved_path is None
    assert failure is not None
    assert "disabled" in failure
@pytest.mark.asyncio
async def test_writable_default_path_uses_downloads_subdir(tmp_path, monkeypatch):
    """With no explicit path, the default filename lands in ``<root>/downloads``."""
    allowed_dir = (tmp_path / "root").resolve()
    allowed_dir.mkdir(parents=True)
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [allowed_dir])

    resolved_path, failure = await main._resolve_writable_file_path(
        raw_path=None,
        default_filename="example.bin",
        ctx=None,
        tool_name="download_media",
    )

    assert failure is None
    expected_path = allowed_dir.joinpath("downloads", "example.bin").resolve()
    assert resolved_path == expected_path
    # The parent directory is expected to exist once the call returns.
    assert resolved_path.parent.exists()
@pytest.mark.asyncio
async def test_extension_allowlist_is_enforced_for_sticker(tmp_path, monkeypatch):
    """A ``.txt`` file inside the allowed root is refused for the ``send_sticker`` tool."""
    allowed_dir = (tmp_path / "root").resolve()
    allowed_dir.mkdir(parents=True)
    wrong_type_file = allowed_dir / "sticker.txt"
    wrong_type_file.write_text("bad", encoding="utf-8")
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [allowed_dir])

    outcome = await main._resolve_readable_file_path(
        raw_path=str(wrong_type_file),
        ctx=None,
        tool_name="send_sticker",
    )
    resolved_path, failure = outcome

    assert resolved_path is None
    assert failure is not None
    assert "extension is not allowed" in failure
@pytest.mark.asyncio
async def test_file_tools_disabled_without_any_roots(monkeypatch):
    """With an empty server allowlist and no context, resolution reports "disabled"."""
    monkeypatch.setattr(main, "SERVER_ALLOWED_ROOTS", [])

    outcome = await main._resolve_readable_file_path(
        raw_path="anything.txt",
        ctx=None,
        tool_name="send_file",
    )
    resolved_path, failure = outcome

    assert resolved_path is None
    assert failure is not None
    assert "disabled" in failure