#!/usr/bin/env python3
"""
PyPI Package MCP Server
A Model Context Protocol server that enables AI assistants to fetch, explore,
and analyze source code from any Python package on PyPI.
Uses FastMCP from the official MCP SDK for simplified server creation.
"""
import io
import logging
import os
import re
import shutil
import tarfile
import tempfile
import zipfile
from pathlib import Path
from typing import Any
from urllib.parse import quote, urlsplit

import requests
from mcp.server.fastmcp import FastMCP
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
# File suffixes included when returning "all code files" of a package.
CODE_EXTENSIONS = {'.py', '.pyx', '.pyi', '.toml', '.yaml', '.yml', '.json', '.md', '.txt', '.cfg', '.ini'}
# Directory names (VCS metadata, caches, virtualenvs) ignored when listing files.
SKIP_DIRECTORIES = {'.git', '.svn', '.hg', '__pycache__', '.pytest_cache', '.tox', '.eggs', '.venv', 'venv', 'node_modules'}
# Maximum number of files returned by get_package_code when no file_path is given.
MAX_CODE_FILES = 20
# Base URL of PyPI's JSON API.
PYPI_JSON_API_URL = "https://pypi.org/pypi"

# Per-process scratch directory for downloaded and extracted packages.
# mkdtemp creates a fresh, uniquely named directory accessible only by the
# current user. A fixed name under the shared temp dir could be pre-created or
# tampered with by another local user, and one server process's cleanup()
# would delete files another process was still using.
TEMP_DIR = Path(tempfile.mkdtemp(prefix="pypi_mcp_cache_"))
# Create the FastMCP server instance. Functions decorated with @mcp.tool()
# below are registered on it as MCP tools; main() starts it via mcp.run().
mcp = FastMCP("pypi-package-mcp")
# Helper functions (module-private; not exposed as MCP tools)
def _fetch_package_info(package_name: str, version: str = "") -> dict[str, Any]:
    """Fetch package metadata from PyPI's JSON API.

    Args:
        package_name: Project name on PyPI.
        version: Release to fetch; an empty string requests the latest release.

    Returns:
        The decoded JSON document returned by PyPI.

    Raises:
        ValueError: If PyPI answers 404 (unknown package or release), any other
            HTTP error status, or the request / JSON decoding fails.
    """
    # Percent-encode each path segment so user input containing '/', '?' or
    # '#' cannot change which PyPI endpoint is requested. Letters, digits,
    # '.', '-' and '_' are left unchanged by quote().
    path = quote(package_name, safe="")
    if version:
        path += "/" + quote(version, safe="")
    url = f"{PYPI_JSON_API_URL}/{path}/json"
    label = f"{package_name} {version}" if version else package_name
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as e:
        # Note: Response.__bool__ is False for error statuses, so compare with None.
        status = e.response.status_code if e.response is not None else None
        if status == 404:
            raise ValueError(f"Package not found: {label} (HTTP 404)") from e
        # Server errors, rate limiting, etc. are not "package not found".
        raise ValueError(f"Failed to fetch package info: {label} (HTTP {status})") from e
    except Exception as e:
        raise ValueError(f"Failed to fetch package info: {str(e)}") from e
def _get_source_url(package_data: dict[str, Any]) -> str | None:
"""Extract source URL from package data."""
urls = package_data.get("urls", [])
# Prefer sdist (source distribution)
for url_info in urls:
if url_info.get("packagetype") == "sdist":
return url_info.get("url")
# Fall back to any other format
for url_info in urls:
if url_info.get("url"):
return url_info.get("url")
return None
def _get_common_prefix(paths: list[str]) -> str:
"""Get common prefix from a list of paths."""
if not paths:
return ""
first_parts = paths[0].split("/")
for i, part in enumerate(first_parts):
for path in paths[1:]:
if not path.startswith("/".join(first_parts[: i + 1])):
return "/".join(first_parts[:i])
return "/".join(first_parts)
def _validate_and_resolve_path(base_path: Path, user_path: str) -> Path:
"""
Validate and resolve a user-provided path to prevent traversal attacks.
Args:
base_path: The base directory path
user_path: The user-provided relative path
Returns:
Resolved absolute path if valid
Raises:
ValueError: If path traversal or other security issues detected
"""
# Prevent absolute paths or patterns
if user_path.startswith('/') or user_path.startswith('\\'):
raise ValueError("Absolute paths are not allowed")
# Remove traversal patterns
clean_path = user_path.replace('\\', '/')
if '..' in clean_path or clean_path.startswith('/'):
raise ValueError("Path traversal detected: '..' sequences not allowed")
# Resolve the full path
full_path = (base_path / clean_path).resolve()
# Verify it's still within base_path
try:
full_path.relative_to(base_path.resolve())
except ValueError:
raise ValueError("Access denied: Path is outside allowed directory")
return full_path
def _validate_package_name(package_name: str) -> None:
"""
Validate package name to prevent injection attacks.
Args:
package_name: The package name to validate
Raises:
ValueError: If package name contains invalid characters
"""
import re
# PyPI package names: lowercase letters, numbers, hyphens, underscores, dots
if not re.match(r'^[a-z0-9]([a-z0-9._-]*[a-z0-9])?$', package_name.lower()):
raise ValueError(f"Invalid package name format: {package_name}")
# Check for null bytes and newlines
if '\0' in package_name or '\n' in package_name or '\r' in package_name:
raise ValueError("Package name contains prohibited characters")
# Archive suffixes (matched against the lower-cased URL path) and their readers.
_TAR_SUFFIXES = (".tar.gz", ".tgz", ".tar.bz2", ".tbz2", ".tar.xz", ".txz", ".tar")
_ZIP_SUFFIXES = (".zip", ".whl", ".egg")


def _strip_archive_prefix(name: str, prefix: str) -> str:
    """Drop the shared top-level directory ``prefix`` from an archive member name."""
    prefix = prefix.rstrip("/")
    # Only strip a whole leading component, never a partial name match.
    if prefix and (name == prefix or name.startswith(prefix + "/")):
        return name[len(prefix):].lstrip("/")
    return name


def _safe_extract_target(dest: Path, name: str) -> Path | None:
    """Map a member name to a path inside ``dest``; None means "skip this member"."""
    parts = [part for part in name.replace("\\", "/").split("/") if part not in ("", ".")]
    if not parts:
        # Empty name, e.g. the archive's root directory after prefix stripping.
        return None
    if ".." in parts:
        logger.warning("Skipping archive member with '..' component: %r", name)
        return None
    target = dest.joinpath(*parts)
    try:
        target.resolve().relative_to(dest.resolve())
    except ValueError:
        logger.warning("Skipping archive member outside extraction directory: %r", name)
        return None
    return target


def _extract_tar_archive(data: bytes, dest: Path) -> None:
    """Write the directories and regular files of a (possibly compressed) tarball into ``dest``."""
    with tarfile.open(fileobj=io.BytesIO(data), mode="r:*") as tar:
        members = tar.getmembers()
        prefix = _get_common_prefix([m.name for m in members])
        for member in members:
            target = _safe_extract_target(dest, _strip_archive_prefix(member.name, prefix))
            if target is None:
                continue
            if member.isdir():
                target.mkdir(parents=True, exist_ok=True)
            elif member.isfile():
                source = tar.extractfile(member)
                if source is None:
                    continue
                target.parent.mkdir(parents=True, exist_ok=True)
                with source, target.open("wb") as out:
                    shutil.copyfileobj(source, out)
            # Symlinks, hard links and device/FIFO entries are deliberately not
            # created: they are not needed to read source code and are the
            # usual way an untrusted tarball escapes its extraction directory.


def _extract_zip_archive(data: bytes, dest: Path) -> None:
    """Write the directories and files of a zip-format archive (zip sdist, wheel, egg) into ``dest``."""
    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        infos = zf.infolist()
        prefix = _get_common_prefix([info.filename for info in infos])
        for info in infos:
            target = _safe_extract_target(dest, _strip_archive_prefix(info.filename, prefix))
            if target is None:
                continue
            if info.is_dir():
                target.mkdir(parents=True, exist_ok=True)
                continue
            target.parent.mkdir(parents=True, exist_ok=True)
            with zf.open(info) as source, target.open("wb") as out:
                shutil.copyfileobj(source, out)


def _download_and_extract(download_url: str, package_name: str) -> Path:
    """Download a package archive and extract it into a fresh cache directory.

    A single top-level directory shared by all members (as in typical sdists)
    is stripped, so files land directly under the returned path. Only
    directories and regular files are written. Members whose path would leave
    the extraction directory are skipped with a warning.

    Args:
        download_url: URL of a tar archive (.tar.gz/.tgz/.tar.bz2/.tar.xz/.tar)
            or zip-format archive (.zip/.whl/.egg).
        package_name: Package name; its sanitised form is the cache sub-directory.

    Returns:
        Path of the directory containing the extracted files.

    Raises:
        ValueError: If the sanitised name is empty, the archive type is
            unsupported, or downloading/extracting fails.
    """
    # Map every character outside [A-Za-z0-9_] to '_' so the directory name can
    # never contain a separator or '..'. For valid PyPI names this is the same
    # as replacing '-' and '.' with '_'.
    sanitized_name = re.sub(r"[^A-Za-z0-9_]", "_", package_name)
    if not sanitized_name:
        # An empty name would make extract_path == TEMP_DIR and wipe the whole cache.
        raise ValueError("Package name must not be empty")
    extract_path = TEMP_DIR / sanitized_name
    # Clean up existing directory
    if extract_path.exists():
        shutil.rmtree(extract_path)
    extract_path.mkdir(parents=True, exist_ok=True)
    try:
        # Decide the format from the URL path only (ignores any query/fragment).
        archive_path = urlsplit(download_url).path.lower()
        if archive_path.endswith(_TAR_SUFFIXES):
            extractor = _extract_tar_archive
        elif archive_path.endswith(_ZIP_SUFFIXES):
            extractor = _extract_zip_archive
        else:
            raise ValueError(f"Unsupported archive format: {download_url}")
        logger.info("Downloading %s", download_url)
        response = requests.get(download_url, timeout=30)
        response.raise_for_status()
        extractor(response.content, extract_path)
        return extract_path
    except Exception as e:
        # Do not leave a half-extracted tree behind for later calls to read.
        shutil.rmtree(extract_path, ignore_errors=True)
        raise ValueError(f"Failed to download/extract: {str(e)}") from e
def _find_code_files(directory: Path) -> list[Path]:
"""Find all code files in a directory."""
code_files = []
for item in directory.rglob("*"):
if item.is_file() and item.suffix in CODE_EXTENSIONS:
if not any(skip_dir in item.parts for skip_dir in SKIP_DIRECTORIES):
code_files.append(item)
return sorted(code_files)
def _get_all_files(directory: Path) -> list[Path]:
"""Get all files in a directory."""
all_files = []
for item in directory.rglob("*"):
if item.is_file() and not any(skip_dir in item.parts for skip_dir in SKIP_DIRECTORIES):
all_files.append(item)
return sorted(all_files)
# FastMCP tool definitions
@mcp.tool()
def search_pypi_packages(query: str, limit: int = 20) -> dict[str, Any]:
    """Search for packages on PyPI.
    Args:
        query: Search query string
        limit: Maximum number of results to return (default: 20)
    Returns:
        Dictionary containing search results with package names, versions, and descriptions.
        If the live search yields nothing, a small built-in list of popular packages
        filtered by the query is returned instead (marked with a "note" key).
    """
    # A negative limit would turn the [:limit] slices into "drop from the end".
    limit = max(0, limit)
    try:
        # Pass the query via ``params`` so requests URL-encodes it; characters
        # such as '&', '#' or spaces would otherwise corrupt the query string.
        response = requests.get(
            PYPI_JSON_API_URL,
            params={"action": "json", "query": query, "submit": "search"},
            timeout=10,
        )
        if response.status_code == 200:
            data = response.json()
            results = [
                {
                    "name": pkg.get("name"),
                    "version": pkg.get("version"),
                    "summary": pkg.get("summary"),
                    "url": f"https://pypi.org/project/{pkg.get('name')}/",
                }
                for pkg in data.get("query", {}).get("results", [])[:limit]
            ]
            if results:
                return {
                    "query": query,
                    "total": len(results),
                    "results": results,
                }
    except Exception as e:
        # Best effort: any failure (network, non-JSON body, unexpected shape)
        # falls through to the static list below, but leave a trace for debugging.
        logger.debug("PyPI search request failed, using fallback results: %s", e)
    # Fallback: return popular packages filtered by query
    popular_packages = [
        {"name": "requests", "version": "latest", "summary": "A simple, yet powerful HTTP library for Python"},
        {"name": "django", "version": "latest", "summary": "The Web framework for perfectionists"},
        {"name": "flask", "version": "latest", "summary": "A simple framework for building web applications"},
        {"name": "numpy", "version": "latest", "summary": "Numerical computing with Python"},
        {"name": "pandas", "version": "latest", "summary": "Data analysis and manipulation library"},
        {"name": "pytest", "version": "latest", "summary": "Python testing framework"},
        {"name": "black", "version": "latest", "summary": "Python code formatter"},
        {"name": "pydantic", "version": "latest", "summary": "Data validation and settings management"},
    ]
    query_lower = query.lower()
    filtered = [
        p | {"url": f"https://pypi.org/project/{p['name']}/"}
        for p in popular_packages
        if query_lower in p["name"].lower() or query_lower in p["summary"].lower()
    ]
    return {
        "query": query,
        "total": len(filtered),
        "results": filtered[:limit],
        "note": "Using fallback results due to API limitations"
    }
@mcp.tool()
def get_package_info(package_name: str, version: str = "") -> dict[str, Any]:
    """Get detailed information about a Python package.
    Args:
        package_name: Name of the package on PyPI
        version: Specific version (optional, defaults to latest)
    Returns:
        Dictionary containing package metadata including version, author, license,
        dependencies and project URLs
    """
    # Validate package name
    _validate_package_name(package_name)
    try:
        package_data = _fetch_package_info(package_name, version)
    except ValueError as e:
        raise ValueError(f"Failed to fetch package: {str(e)}") from e
    # The PyPI JSON API may report these fields as null (e.g. requires_dist
    # for a package without dependencies). ``.get(key, default)`` would then
    # return None, so ``or`` normalises them to empty containers.
    package_info = package_data.get("info") or {}
    return {
        "name": package_info.get("name"),
        "version": package_info.get("version"),
        "summary": package_info.get("summary"),
        "author": package_info.get("author"),
        "author_email": package_info.get("author_email"),
        "license": package_info.get("license"),
        "home_page": package_info.get("home_page"),
        "requires_dist": package_info.get("requires_dist") or [],
        "classifiers": package_info.get("classifiers") or [],
        # Additive key: many projects list their links here rather than in home_page.
        "project_urls": package_info.get("project_urls") or {},
    }
@mcp.tool()
def list_package_files(package_name: str, version: str = "") -> dict[str, Any]:
    """List all files in a Python package.
    Args:
        package_name: Name of the package on PyPI
        version: Specific version (optional, defaults to latest)
    Returns:
        Dictionary containing list of all files in the package
    """
    # Validate like the other tools do: the name is later used to build a
    # directory path under the extraction cache in _download_and_extract.
    _validate_package_name(package_name)
    package_data = _fetch_package_info(package_name, version)
    download_url = _get_source_url(package_data)
    if not download_url:
        raise ValueError("No source distribution found")
    extracted_path = _download_and_extract(download_url, package_name)
    package_info = package_data.get("info", {})
    # Sort the final string paths once; this string order is what callers see.
    file_list = sorted(
        str(f.relative_to(extracted_path)) for f in _get_all_files(extracted_path)
    )
    return {
        "package": package_name,
        "version": package_info.get("version"),
        "total_files": len(file_list),
        "files": file_list,
    }
@mcp.tool()
def get_package_code(package_name: str, version: str = "", file_path: str = "") -> dict[str, Any]:
    """Get source code from a Python package.
    Args:
        package_name: Name of the package on PyPI
        version: Specific version (optional, defaults to latest)
        file_path: Specific file to retrieve (optional, retrieves all code files if not specified)
    Returns:
        Dictionary containing source code content. If file_path specified, returns single file.
        Otherwise returns multiple code files (up to 20).
    """
    # Validate package name
    _validate_package_name(package_name)
    try:
        package_data = _fetch_package_info(package_name, version)
    except ValueError as e:
        raise ValueError(f"Failed to fetch package info: {str(e)}") from e
    download_url = _get_source_url(package_data)
    if not download_url:
        raise ValueError("No source distribution found for this package")
    extracted_path = _download_and_extract(download_url, package_name)
    package_info = package_data.get("info", {})

    if file_path:
        # Validate and resolve the file path to prevent traversal
        try:
            full_path = _validate_and_resolve_path(extracted_path, file_path)
        except ValueError as e:
            raise ValueError(f"Invalid file path: {str(e)}") from e
        # is_file() rather than exists(): a directory would pass exists() and
        # then crash read_text() with IsADirectoryError.
        if not full_path.is_file():
            raise FileNotFoundError(f"File not found: {file_path}")
        # Lenient decoding: undecodable bytes are dropped instead of raising.
        content = full_path.read_text(encoding="utf-8", errors="ignore")
        return {
            "package": package_name,
            "version": package_info.get("version"),
            "file": file_path,
            "content": content,
        }

    code_files = _find_code_files(extracted_path)
    # Only the first MAX_CODE_FILES files are read; total_files reports all matches.
    files_data = [
        {
            "path": str(code_file.relative_to(extracted_path)),
            "content": code_file.read_text(encoding="utf-8", errors="ignore"),
        }
        for code_file in code_files[:MAX_CODE_FILES]
    ]
    return {
        "package": package_name,
        "version": package_info.get("version"),
        "description": package_info.get("summary"),
        "files": files_data,
        "total_files": len(code_files),
        "showing": len(files_data),
    }
def cleanup() -> None:
    """Delete the package cache directory (TEMP_DIR) if it is present.

    Any exception raised while checking or deleting is logged as an error
    rather than propagated, so callers are not interrupted by a failed cleanup.
    """
    try:
        cache_present = TEMP_DIR.exists()
        if cache_present:
            shutil.rmtree(TEMP_DIR)
    except Exception as err:
        logger.error(f"Failed to cleanup temp directory: {err}")
def main() -> None:
    """Main entry point: run the MCP server and remove the package cache on exit.

    Environment variables:
        TRANSPORT_MODE: "http" selects the streamable-HTTP transport; any other
            value (default "stdio") uses stdio.
        HOST: Bind address in HTTP mode (default "0.0.0.0").
        PORT: Port in HTTP mode (default 3001).
    """
    import signal

    def signal_handler(sig: int, frame: Any) -> None:
        logger.info("Received signal, cleaning up...")
        cleanup()
        # Hard exit immediately after cleanup; no further Python shutdown handling runs.
        os._exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    transport_mode = os.getenv("TRANSPORT_MODE", "stdio").lower()
    try:
        logger.info(f"Starting PyPI Package MCP server in {transport_mode} mode")
        if transport_mode == "http":
            port = int(os.getenv("PORT", "3001"))
            # "0.0.0.0" (all interfaces) stays the default for backward
            # compatibility; set HOST=127.0.0.1 to keep the server local-only.
            mcp.settings.host = os.getenv("HOST", "0.0.0.0")
            mcp.settings.port = port
            mcp.run(transport="streamable-http")
        else:
            mcp.run(transport="stdio")
    except Exception as e:
        logger.error(f"Server error: {e}")
        raise
    finally:
        # Runs when mcp.run() returns normally as well as on errors, so the
        # cache directory is not left behind after a clean shutdown.
        cleanup()


if __name__ == "__main__":
    main()